package com.yahoo.tensor.functions;

import com.yahoo.nativec.PosixFAdvise;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Label;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.LabelCache;
import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:com/yahoo/tensor/functions/Concat.class */
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
    /** First operand; its cells occupy the leading positions along the concat dimension. */
    private final TensorFunction<NAMETYPE> argumentA;
    /** Second operand; its cells are placed after those of {@code argumentA} along the concat dimension. */
    private final TensorFunction<NAMETYPE> argumentB;
    /**
     * Name of the dimension to concatenate along. {@code toString} and {@code type} pass it
     * through the context's {@code resolveBinding}; {@code evaluate} uses it as-is.
     */
    private final String dimension;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: com.yahoo.tensor.functions.Concat$1, reason: invalid class name */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$1.class */
    /**
     * Synthetic switch-map tables that translate an enum constant's ordinal into a dense case
     * number: 1..3 for {@link DimType} and 1..4 for {@link ConcatPlan.CombineHow}.
     * <p>
     * Both tables must be declared and allocated before the static initializer fills them in.
     * A slot that is never filled stays 0, which matches no case label in a switch using it.
     */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$tensor$functions$Concat$DimType;
        static final /* synthetic */ int[] $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow;

        static {
            $SwitchMap$com$yahoo$tensor$functions$Concat$DimType = new int[DimType.values().length];
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$DimType[DimType.common.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$DimType[DimType.separate.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$DimType[DimType.concat.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow = new int[ConcatPlan.CombineHow.values().length];
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow[ConcatPlan.CombineHow.left.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow[ConcatPlan.CombineHow.right.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow[ConcatPlan.CombineHow.both.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
            try {
                $SwitchMap$com$yahoo$tensor$functions$Concat$ConcatPlan$CombineHow[ConcatPlan.CombineHow.concat.ordinal()] = 4;
            } catch (NoSuchFieldError ignored) {
                // Constant not present: leave its slot at 0 (no matching case).
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$CellVector.class */
    /**
     * Growable list of cell values addressed by position along the concat dimension.
     * Positions below the highest one written, but never written themselves, hold 0.0.
     */
    public static class CellVector {
        ArrayList<Double> values = new ArrayList<>();

        CellVector() {
        }

        /**
         * Stores {@code d} at position {@code i}, first extending the list with 0.0 entries
         * until position {@code i} exists.
         */
        void setValue(int i, double d) {
            for (int size = values.size(); size <= i; size++) {
                values.add(0.0d);
            }
            values.set(i, d);
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$CellVectorMap.class */
    /** Map from an address to the {@link CellVector} of values stored under it. */
    public static class CellVectorMap {
        Map<TensorAddress, CellVector> map = new HashMap<>();

        CellVectorMap() {
        }

        /** Returns the vector stored for {@code tensorAddress}, first storing a new empty one if there is none. */
        CellVector lookupCreate(TensorAddress tensorAddress) {
            return map.computeIfAbsent(tensorAddress, ignored -> new CellVector());
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$CellVectorMapMap.class */
    /** Map from an address to the {@link CellVectorMap} stored under it (the outer level of a decomposition). */
    public static class CellVectorMapMap {
        Map<TensorAddress, CellVectorMap> map = new HashMap<>();

        CellVectorMapMap() {
        }

        /** Returns the inner map stored for {@code tensorAddress}, first storing a new empty one if there is none. */
        CellVectorMap lookupCreate(TensorAddress tensorAddress) {
            return map.computeIfAbsent(tensorAddress, ignored -> new CellVectorMap());
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$ConcatPlan.class */
    /**
     * Recipe for concatenating two tensor types along {@code concatDimension}.
     * <p>
     * The constructor walks the dimension lists of both argument types in step, comparing names
     * (NOTE(review): this assumes each type lists its dimensions sorted by name — confirm). For
     * every dimension it records:
     * <ul>
     * <li>in {@link #splitInfoA} / {@link #splitInfoB}: how that argument's dimension is used when
     *     decomposing the argument, in the argument's own dimension order;</li>
     * <li>in {@link #combineHow}: where the label of the corresponding result dimension comes from.</li>
     * </ul>
     */
    public static class ConcatPlan {
        final TensorType resultType;
        final String concatDimension;
        SplitHow splitInfoA = new SplitHow();
        SplitHow splitInfoB = new SplitHow();
        List<CombineHow> combineHow = new ArrayList<>();

        /* JADX INFO: Access modifiers changed from: package-private */
        /* loaded from: input_file:com/yahoo/tensor/functions/Concat$ConcatPlan$CombineHow.class */
        /** Source of the label for one result dimension. */
        public enum CombineHow {
            /** From the address over the dimensions only argument A has. */
            left,
            /** From the address over the dimensions only argument B has. */
            right,
            /** From the address over the dimensions both arguments share. */
            both,
            /** The position along the concat dimension. */
            concat
        }

        /** Records a dimension that only argument A has. */
        void aOnly(String str) {
            if (str.equals(this.concatDimension)) {
                this.splitInfoA.handleDims.add(DimType.concat);
                this.combineHow.add(CombineHow.concat);
            } else {
                this.splitInfoA.handleDims.add(DimType.separate);
                this.combineHow.add(CombineHow.left);
            }
        }

        /** Records a dimension that only argument B has. */
        void bOnly(String str) {
            if (str.equals(this.concatDimension)) {
                this.splitInfoB.handleDims.add(DimType.concat);
                this.combineHow.add(CombineHow.concat);
            } else {
                this.splitInfoB.handleDims.add(DimType.separate);
                this.combineHow.add(CombineHow.right);
            }
        }

        /** Records a dimension that both arguments have. */
        void bothAandB(String str) {
            if (str.equals(this.concatDimension)) {
                this.splitInfoA.handleDims.add(DimType.concat);
                this.splitInfoB.handleDims.add(DimType.concat);
                this.combineHow.add(CombineHow.concat);
            } else {
                this.splitInfoA.handleDims.add(DimType.common);
                this.splitInfoB.handleDims.add(DimType.common);
                this.combineHow.add(CombineHow.both);
            }
        }

        /**
         * Builds the plan for concatenating a tensor of type {@code tensorType} with one of type
         * {@code tensorType2} along dimension {@code str}.
         */
        ConcatPlan(TensorType tensorType, TensorType tensorType2, String str) {
            this.resultType = TypeResolver.concat(tensorType, tensorType2, str);
            this.concatDimension = str;
            List<TensorType.Dimension> dimsA = tensorType.dimensions();
            List<TensorType.Dimension> dimsB = tensorType2.dimensions();
            int ia = 0;
            int ib = 0;
            while (ia < dimsA.size() && ib < dimsB.size()) {
                String nameA = dimsA.get(ia).name();
                String nameB = dimsB.get(ib).name();
                int cmp = nameA.compareTo(nameB);
                if (cmp == 0) {
                    bothAandB(nameA);
                    ia++;
                    ib++;
                } else if (cmp < 0) {
                    aOnly(nameA);
                    ia++;
                } else {
                    bOnly(nameB);
                    ib++;
                }
            }
            while (ia < dimsA.size()) {
                aOnly(dimsA.get(ia++).name());
            }
            while (ib < dimsB.size()) {
                bOnly(dimsB.get(ib++).name());
            }
            // If neither argument has the concat dimension, no entry was recorded for it above,
            // but the result type has it: insert its slot at the result's position for it.
            // The concat dimension is looked up in a type built by TypeResolver.concat with this
            // name, so it is presumed present (get() throws NoSuchElementException otherwise).
            if (this.combineHow.size() < this.resultType.rank()) {
                int concatIndex = this.resultType.indexOfDimension(str).get();
                this.combineHow.add(concatIndex, CombineHow.concat);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$DimType.class */
    /**
     * Role of one argument dimension while that argument is decomposed:
     * <ul>
     * <li>{@code common}: shared with the other argument; its label goes into the outer (common) address;</li>
     * <li>{@code separate}: only this argument has it; its label goes into the inner address;</li>
     * <li>{@code concat}: the concat dimension; its numeric label is the position in the cell vector.</li>
     * </ul>
     */
    public enum DimType { common, separate, concat }

    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$Helper.class */
    /**
     * General concat used when the two arguments are not both {@link IndexedTensor}s.
     * <p>
     * Each argument is decomposed into a two-level map: outer key = labels of the dimensions
     * shared with the other argument, inner key = labels of the dimensions only this argument
     * has, value = the cells ordered by their concat-dimension label. The two decompositions are
     * joined on the outer key; every pair of inner entries yields result cells, A's values at
     * concat positions {@code 0..n-1} and B's values at {@code n, n+1, ...}, where {@code n} is
     * the (uniform) length of A's vectors. The result is computed eagerly in the constructor.
     */
    static class Helper {
        ConcatPlan plan;
        Tensor result;

        Helper(Tensor tensor, Tensor tensor2, String str) {
            this.plan = new ConcatPlan(tensor.type(), tensor2.type(), str);
            this.result = merge(decompose(tensor, this.plan.splitInfoA), decompose(tensor2, this.plan.splitInfoB));
        }

        /**
         * Returns the common length of every cell vector in {@code cellVectorMapMap}, or 1 if it
         * contains none.
         *
         * @throws IllegalArgumentException if the vectors do not all have the same length
         */
        static int concatDimensionSize(CellVectorMapMap cellVectorMapMap) {
            Set<Integer> sizes = new HashSet<>();
            for (CellVectorMap inner : cellVectorMapMap.map.values()) {
                for (CellVector cells : inner.map.values()) {
                    sizes.add(cells.values.size());
                }
            }
            if (sizes.isEmpty()) {
                return 1;
            }
            if (sizes.size() == 1) {
                return sizes.iterator().next();
            }
            throw new IllegalArgumentException("inconsistent size of concat dimension, had " + sizes.size() + " different values");
        }

        /**
         * Builds one result address by walking {@code plan.combineHow} and taking each label,
         * in order, from the matching source.
         *
         * @param common      labels of the dimensions both arguments share
         * @param aOnly       labels of the dimensions only argument A has
         * @param bOnly       labels of the dimensions only argument B has
         * @param concatIndex position along the concat dimension
         */
        TensorAddress combine(TensorAddress common, TensorAddress aOnly, TensorAddress bOnly, int concatIndex) {
            Label[] labels = new Label[this.plan.resultType.rank()];
            int out = 0;
            int commonIdx = 0;
            int aIdx = 0;
            int bIdx = 0;
            for (ConcatPlan.CombineHow how : this.plan.combineHow) {
                switch (how) {
                    case left:
                        labels[out++] = aOnly.objectLabel(aIdx++);
                        break;
                    case right:
                        labels[out++] = bOnly.objectLabel(bIdx++);
                        break;
                    case both:
                        labels[out++] = common.objectLabel(commonIdx++);
                        break;
                    case concat:
                        labels[out++] = LabelCache.GLOBAL.getOrCreateLabel(concatIndex);
                        break;
                    default:
                        throw new IllegalArgumentException("cannot handle: " + how);
                }
            }
            // labels is allocated per call and not touched afterwards, so handing it to
            // ofUnsafe (presumably no defensive copy) is fine here.
            return TensorAddressAny.ofUnsafe(labels);
        }

        /**
         * Joins the decompositions of A ({@code cellVectorMapMap}) and B ({@code cellVectorMapMap2})
         * on their common address and emits the concatenated cells into a new tensor.
         */
        Tensor merge(CellVectorMapMap cellVectorMapMap, CellVectorMapMap cellVectorMapMap2) {
            Tensor.Builder builder = Tensor.Builder.of(this.plan.resultType);
            int aConcatSize = concatDimensionSize(cellVectorMapMap);
            for (Map.Entry<TensorAddress, CellVectorMap> entry : cellVectorMapMap.map.entrySet()) {
                TensorAddress common = entry.getKey();
                // Values are created by computeIfAbsent and never null, so null means "absent".
                CellVectorMap rhs = cellVectorMapMap2.map.get(common);
                if (rhs == null) {
                    continue;
                }
                for (Map.Entry<TensorAddress, CellVector> left : entry.getValue().map.entrySet()) {
                    TensorAddress aOnly = left.getKey();
                    List<Double> aValues = left.getValue().values;
                    for (Map.Entry<TensorAddress, CellVector> right : rhs.map.entrySet()) {
                        TensorAddress bOnly = right.getKey();
                        List<Double> bValues = right.getValue().values;
                        for (int i = 0; i < aValues.size(); i++) {
                            builder.cell(combine(common, aOnly, bOnly, i), aValues.get(i).doubleValue());
                        }
                        // B's cells follow A's along the concat dimension.
                        for (int i = 0; i < bValues.size(); i++) {
                            builder.cell(combine(common, aOnly, bOnly, i + aConcatSize), bValues.get(i).doubleValue());
                        }
                    }
                }
            }
            return builder.build();
        }

        /**
         * Splits every cell of {@code tensor} according to {@code splitHow}: common labels form the
         * outer key, separate labels the inner key, and the concat label (0 if the tensor lacks the
         * concat dimension) the position in the cell vector.
         */
        CellVectorMapMap decompose(Tensor tensor, SplitHow splitHow) {
            int numCommon = (int) splitHow.numCommon();
            int numSeparate = (int) splitHow.numSeparate();
            CellVectorMapMap result = new CellVectorMapMap();
            Iterator<Tensor.Cell> cells = tensor.cellIterator();
            while (cells.hasNext()) {
                Tensor.Cell cell = cells.next();
                TensorAddress address = cell.getKey();
                // Fresh arrays per cell: the addresses built from them become HashMap keys, and
                // ofUnsafe may keep the array rather than copy it, so reusing one array across
                // cells could silently rewrite keys that are already stored.
                Label[] commonLabels = new Label[numCommon];
                Label[] separateLabels = new Label[numSeparate];
                long concatIndex = 0;
                int commonIdx = 0;
                int separateIdx = 0;
                for (int d = 0; d < splitHow.handleDims.size(); d++) {
                    DimType dimType = splitHow.handleDims.get(d);
                    switch (dimType) {
                        case common:
                            commonLabels[commonIdx++] = address.objectLabel(d);
                            break;
                        case separate:
                            separateLabels[separateIdx++] = address.objectLabel(d);
                            break;
                        case concat:
                            concatIndex = address.numericLabel(d);
                            break;
                        default:
                            throw new IllegalArgumentException("cannot handle: " + dimType);
                    }
                }
                result.lookupCreate(TensorAddressAny.ofUnsafe(commonLabels))
                      .lookupCreate(TensorAddressAny.ofUnsafe(separateLabels))
                      .setValue((int) concatIndex, cell.getValue().doubleValue());
            }
            return result;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/yahoo/tensor/functions/Concat$SplitHow.class */
    /**
     * Per-argument decomposition recipe: {@link #handleDims} holds one {@link DimType} for each
     * dimension of the argument, in the argument's dimension order.
     */
    public static class SplitHow {
        List<DimType> handleDims = new ArrayList<>();

        SplitHow() {
        }

        /** Returns the number of dimensions marked {@link DimType#common}. */
        long numCommon() {
            return countOf(DimType.common);
        }

        /** Returns the number of dimensions marked {@link DimType#separate}. */
        long numSeparate() {
            return countOf(DimType.separate);
        }

        /** Counts the entries of {@link #handleDims} equal to {@code type}. */
        private long countOf(DimType type) {
            return handleDims.stream().filter(t -> t == type).count();
        }
    }

    /**
     * Creates a function concatenating two tensors along a dimension.
     *
     * @param tensorFunction  the first argument
     * @param tensorFunction2 the second argument, placed after the first along {@code str}
     * @param str             name of the dimension to concatenate along
     * @throws NullPointerException if any argument is null
     */
    public Concat(TensorFunction<NAMETYPE> tensorFunction, TensorFunction<NAMETYPE> tensorFunction2, String str) {
        this.argumentA = Objects.requireNonNull(tensorFunction, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(tensorFunction2, "The second argument tensor cannot be null");
        this.dimension = Objects.requireNonNull(str, "The dimension cannot be null");
    }

    /** Returns the two argument functions, first argument first, as an immutable list. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argumentA, argumentB);
    }

    /**
     * Returns a new concat of the two given functions along the same dimension.
     *
     * @throws IllegalArgumentException if {@code list} does not hold exactly two functions
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> list) {
        if (list.size() != 2) {
            throw new IllegalArgumentException("Concat must have 2 arguments, got " + list.size());
        }
        return new Concat<>(list.get(0), list.get(1), this.dimension);
    }

    /** Returns a concat over the primitive forms of both arguments, along the same dimension. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Concat<>(this.argumentA.toPrimitive(), this.argumentB.toPrimitive(), this.dimension);
    }

    /** Renders this function as {@code concat(<a>, <b>, <dimension>)}, resolving the dimension binding. */
    @Override
    public String toString(ToStringContext<NAMETYPE> toStringContext) {
        String first = argumentA.toString(toStringContext);
        String second = argumentB.toString(toStringContext);
        String dim = toStringContext.resolveBinding(dimension);
        return "concat(" + first + ", " + second + ", " + dim + ")";
    }

    /** Hashes the function name together with both arguments and the dimension name. */
    @Override
    public int hashCode() {
        Object[] parts = {"concat", argumentA, argumentB, dimension};
        return Objects.hash(parts);
    }

    /** Returns the concat type of both argument types along the (binding-resolved) dimension. */
    @Override
    public TensorType type(TypeContext<NAMETYPE> typeContext) {
        TensorType typeA = argumentA.type(typeContext);
        TensorType typeB = argumentB.type(typeContext);
        String resolved = typeContext.resolveBinding(dimension);
        return TypeResolver.concat(typeA, typeB, resolved);
    }

    /**
     * Evaluates both arguments and concatenates them: the indexed-only path when both results
     * are {@link IndexedTensor}s, the general {@link Helper} otherwise.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> evaluationContext) {
        Tensor a = argumentA.evaluate(evaluationContext);
        Tensor b = argumentB.evaluate(evaluationContext);
        boolean bothIndexed = a instanceof IndexedTensor && b instanceof IndexedTensor;
        if (bothIndexed) {
            return oldEvaluate(a, b);
        }
        return new Helper(a, b, dimension).result;
    }

    /**
     * Concatenates two indexed tensors along {@link #dimension}. A missing concat dimension is
     * added with size 1 first. Cells of {@code tensor} are written with an offset of 0 and cells
     * of {@code tensor2} with an offset equal to {@code tensor}'s size in that dimension.
     */
    private Tensor oldEvaluate(Tensor tensor, Tensor tensor2) {
        TensorType resultType = TypeResolver.concat(tensor.type(), tensor2.type(), dimension);
        Tensor withDimA = ensureIndexedDimension(dimension, tensor, resultType.valueType());
        Tensor withDimB = ensureIndexedDimension(dimension, tensor2, resultType.valueType());
        IndexedTensor a = (IndexedTensor) withDimA;
        IndexedTensor b = (IndexedTensor) withDimB;
        Tensor.Builder builder = Tensor.Builder.of(resultType, concatSize(resultType, a, b, dimension));
        long aSizeInDimension = a.type().indexOfDimension(dimension)
                                 .map(index -> a.dimensionSizes().size(index))
                                 .orElseThrow(RuntimeException::new);
        int[] aToResult = mapIndexes(withDimA.type(), resultType);
        int[] bToResult = mapIndexes(withDimB.type(), resultType);
        concatenateTo(a, b, aSizeInDimension, resultType, aToResult, bToResult, builder);
        concatenateTo(b, a, 0L, resultType, bToResult, aToResult, builder);
        return builder.build();
    }

    /**
     * For each subspace address of {@code indexedTensor} over its non-concat dimensions, combines
     * that address with every cell of {@code indexedTensor2} (iterated over the same dimension
     * set) via {@link #combineAddresses}, offsetting the concat label by {@code j}. Cells whose
     * labels conflict with the address (combineAddresses returns null) are skipped.
     */
    private void concatenateTo(IndexedTensor indexedTensor, IndexedTensor indexedTensor2, long j, TensorType tensorType, int[] iArr, int[] iArr2, Tensor.Builder builder) {
        Set<String> otherDimensions = indexedTensor.type().dimensionNames().stream()
                                                   .filter(name -> !name.equals(dimension))
                                                   .collect(Collectors.toSet());
        for (Iterator<IndexedTensor.SubspaceIterator> outer = indexedTensor.subspaceIterator(otherDimensions); outer.hasNext(); ) {
            IndexedTensor.SubspaceIterator outerSubspace = outer.next();
            TensorAddress outerAddress = outerSubspace.address();
            for (Iterator<IndexedTensor.SubspaceIterator> inner = indexedTensor2.subspaceIterator(otherDimensions); inner.hasNext(); ) {
                IndexedTensor.SubspaceIterator innerSubspace = inner.next();
                while (innerSubspace.hasNext()) {
                    Tensor.Cell cell = innerSubspace.next();
                    TensorAddress combined = combineAddresses(outerAddress, iArr, cell.getKey(), iArr2, tensorType, j, dimension);
                    if (combined != null) {
                        builder.cell(combined, cell.getValue().doubleValue());
                    }
                }
                // outerSubspace is only consulted for its address here; it is reset after each inner subspace.
                outerSubspace.reset();
            }
        }
    }

    /**
     * Returns {@code tensor} unchanged if it has {@code str} as an indexed dimension. If it lacks
     * {@code str} and has no mapped dimensions, returns it multiplied by a single-cell tensor
     * (value type {@code value}, indexed dimension {@code str} of size 1, cell value 1.0),
     * presumably yielding the same cells with a size-1 dimension {@code str} added.
     *
     * @throws IllegalArgumentException if {@code str} is a mapped dimension of the tensor, or the
     *         tensor lacks {@code str} and has mapped dimensions
     */
    private Tensor ensureIndexedDimension(String str, Tensor tensor, TensorType.Value value) {
        Optional<TensorType.Dimension> existing = tensor.type().dimension(str);
        if (!existing.isPresent()) {
            if (tensor.type().hasMappedDimensions()) {
                throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
            }
            TensorType unitType = new TensorType.Builder(value).indexed(str, 1L).build();
            Tensor unit = Tensor.Builder.of(unitType).cell(1.0f, 0).build();
            return tensor.multiply(unit);
        }
        if (!existing.get().isIndexed()) {
            throw new IllegalArgumentException("Concat in dimension '" + str + "' requires that dimension to be indexed or absent, but got a tensor with type " + tensor.type());
        }
        return tensor;
    }

    /**
     * Computes the result dimension sizes. For each result dimension:
     * <ul>
     * <li>the concat dimension gets the sum of both sizes;</li>
     * <li>any other dimension gets the smaller size when both tensors have it with different
     *     sizes, and otherwise the larger one (a missing dimension counts as size 0).</li>
     * </ul>
     */
    private DimensionSizes concatSize(TensorType tensorType, IndexedTensor indexedTensor, IndexedTensor indexedTensor2, String str) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(tensorType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); i++) {
            String name = tensorType.dimensions().get(i).name();
            long sizeA = indexedSizeOrZero(indexedTensor, name);
            long sizeB = indexedSizeOrZero(indexedTensor2, name);
            if (name.equals(str)) {
                builder.set(i, sizeA + sizeB);
            } else if (sizeA != 0 && sizeB != 0 && sizeA != sizeB) {
                builder.set(i, Math.min(sizeA, sizeB));
            } else {
                builder.set(i, Math.max(sizeA, sizeB));
            }
        }
        return builder.build();
    }

    /** Returns the size of dimension {@code name} in {@code tensor}, or 0 if the tensor lacks it. */
    private static long indexedSizeOrZero(IndexedTensor tensor, String name) {
        return tensor.type().indexOfDimension(name)
                     .map(index -> tensor.dimensionSizes().size(index))
                     .orElse(0L);
    }

    /**
     * Merges two addresses into one of {@code tensorType}, using {@code iArr}/{@code iArr2} to map
     * each address position to a result position and offsetting the concat label by {@code j}.
     *
     * @return the merged address, or null when the second address conflicts with the first on a
     *         non-concat dimension
     */
    private TensorAddress combineAddresses(TensorAddress tensorAddress, int[] iArr, TensorAddress tensorAddress2, int[] iArr2, TensorType tensorType, long j, String str) {
        long[] labels = new long[tensorType.dimensions().size()];
        Arrays.fill(labels, -1L); // -1 marks "not yet set"
        int concatIndex = tensorType.indexOfDimension(str).get();
        // The first call's result is ignored; conflicts are checked when the second address is laid on top.
        mapContent(tensorAddress, labels, iArr, concatIndex, j);
        boolean compatible = mapContent(tensorAddress2, labels, iArr2, concatIndex, j);
        return compatible ? TensorAddress.of(labels) : null;
    }

    /**
     * Returns, for each dimension of {@code tensorType} in order, its index in {@code tensorType2},
     * or -1 when {@code tensorType2} has no dimension of that name.
     */
    private int[] mapIndexes(TensorType tensorType, TensorType tensorType2) {
        List<TensorType.Dimension> source = tensorType.dimensions();
        int[] indexes = new int[source.size()];
        int pos = 0;
        for (TensorType.Dimension d : source) {
            indexes[pos++] = tensorType2.indexOfDimension(d.name()).orElse(-1);
        }
        return indexes;
    }

    /**
     * Writes the numeric labels of {@code address} into {@code result} at the positions given by
     * {@code toResultIndex}. The label landing on {@code concatIndex} is shifted by {@code offset}
     * and overwrites unconditionally; any other label must match what is already there, unless
     * the slot is still -1.
     *
     * @return false as soon as a non-concat label conflicts, true otherwise
     */
    private boolean mapContent(TensorAddress address, long[] result, int[] toResultIndex, int concatIndex, long offset) {
        for (int pos = 0; pos < address.size(); pos++) {
            int target = toResultIndex[pos];
            long label = address.numericLabel(pos);
            if (target == concatIndex) {
                result[target] = label + offset;
                continue;
            }
            boolean alreadySet = result[target] != -1;
            if (alreadySet && result[target] != label) {
                return false;
            }
            result[target] = label;
        }
        return true;
    }
}
