package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.functions.Reduce;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;

/* loaded from: input_file:com/yahoo/tensor/functions/ReduceJoin.class */
/**
 * A composite tensor function which evaluates a reduce over a join of two arguments.
 *
 * <p>It is equivalent to {@code Reduce(Join(argumentA, argumentB, combinator), aggregator, dimensions)}
 * (see {@link #toPrimitive()}), but when both evaluated arguments satisfy {@link #canOptimize(Tensor, Tensor)}
 * the combined and aggregated cells are computed directly from the two indexed tensors instead of going
 * through {@code Join.evaluate} followed by {@code Reduce.evaluate}.
 *
 * @param <NAMETYPE> the name type used by the evaluation context
 */
public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final DoubleBinaryOperator combinator;
    private final Reduce.Aggregator aggregator;
    private final List<String> dimensions;

    /**
     * A counter over all index combinations of the dimensions of a tensor type.
     * The last dimension is incremented first; when a position reaches its bound it
     * wraps to 0 and the position before it is incremented.
     *
     * <p>NOTE(review): the decompiler reports that this class was originally private;
     * it is kept public here so the visible interface is unchanged.
     */
    public static class MultiDimensionIterator {
        /** The size of each dimension, in the dimension order of the type. */
        private final long[] bounds;
        /** The current index in each dimension. */
        private final long[] iterator;
        /** The number of index combinations not yet consumed by {@link #next()}. */
        private long remaining;

        /**
         * Creates an iterator positioned at the all-zero index.
         *
         * @param tensorType the type to iterate over; every dimension must have a size,
         *                   otherwise the unchecked {@code Optional.get()} fails
         */
        MultiDimensionIterator(TensorType tensorType) {
            int dimensionCount = tensorType.dimensions().size();
            this.bounds = new long[dimensionCount];
            this.iterator = new long[dimensionCount];
            for (int d = 0; d < dimensionCount; d++) {
                this.bounds[d] = tensorType.dimensions().get(d).size().get().longValue();
            }
            reset();
        }

        /** Returns the number of dimensions this iterates over. */
        public int length() {
            return iterator.length;
        }

        /** Returns whether there are index combinations left. */
        public boolean hasNext() {
            return remaining > 0;
        }

        /**
         * Moves back to the all-zero index and recomputes the number of remaining combinations
         * as the product of all bounds (1 when there are no dimensions).
         */
        public void reset() {
            remaining = 1L;
            for (int d = iterator.length - 1; d >= 0; d--) {
                iterator[d] = 0;
                remaining *= bounds[d];
            }
        }

        /** Advances to the next index combination, carrying from the last dimension towards the first. */
        public void next() {
            for (int d = iterator.length - 1; d >= 0; d--) {
                iterator[d]++;
                if (iterator[d] < bounds[d]) {
                    break;
                }
                iterator[d] = 0; // wrapped: carry into the dimension before this one
            }
            remaining--;
        }

        @Override
        public String toString() {
            return Arrays.toString(iterator);
        }
    }

    /**
     * Creates a reduce-join from an existing reduce and join pair, taking the two arguments
     * and the combinator from the join and the aggregator and dimensions from the reduce.
     *
     * @param reduce the reduce supplying the aggregator and the dimensions
     * @param join the join supplying the two arguments and the combinator
     */
    public ReduceJoin(Reduce<NAMETYPE> reduce, Join<NAMETYPE> join) {
        this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
    }

    /**
     * Creates a reduce-join.
     *
     * @param tensorFunction the first (A) argument
     * @param tensorFunction2 the second (B) argument
     * @param doubleBinaryOperator the function combining a cell value of A with a cell value of B
     * @param aggregator the aggregator applied to the combined values
     * @param list the names of the dimensions to reduce over; copied, so later changes
     *             to the given list do not affect this function
     */
    public ReduceJoin(TensorFunction<NAMETYPE> tensorFunction, TensorFunction<NAMETYPE> tensorFunction2, DoubleBinaryOperator doubleBinaryOperator, Reduce.Aggregator aggregator, List<String> list) {
        this.argumentA = tensorFunction;
        this.argumentB = tensorFunction2;
        this.combinator = doubleBinaryOperator;
        this.aggregator = aggregator;
        this.dimensions = List.copyOf(list);
    }

    /** Returns the two arguments of this function, A first. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argumentA, argumentB);
    }

    /**
     * Returns a copy of this function with the given arguments.
     *
     * @param list the new arguments, A first
     * @throws IllegalArgumentException if the list does not contain exactly 2 arguments
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> list) {
        if (list.size() != 2) {
            throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + list.size());
        }
        return new ReduceJoin<>(list.get(0), list.get(1), combinator, aggregator, dimensions);
    }

    /** Returns this as a reduce over a join of the primitive forms of the two arguments. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        Join<NAMETYPE> join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
        return new Reduce<>(join, aggregator, dimensions);
    }

    /**
     * Evaluates both arguments and returns the reduced join of the results, using the
     * direct evaluation when {@link #canOptimize(Tensor, Tensor)} allows it and
     * {@code Join.evaluate} followed by {@code Reduce.evaluate} otherwise.
     */
    @Override // com.yahoo.tensor.functions.CompositeTensorFunction, com.yahoo.tensor.functions.TensorFunction
    public final Tensor evaluate(EvaluationContext<NAMETYPE> evaluationContext) {
        Tensor a = argumentA.evaluate(evaluationContext);
        Tensor b = argumentB.evaluate(evaluationContext);
        TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
        if (canOptimize(a, b)) {
            return evaluate((IndexedTensor) a, (IndexedTensor) b, joinedType);
        }
        return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator);
    }

    /**
     * Returns whether the direct evaluation in this class can be used for the given tensors.
     *
     * <p>This requires that both tensors have at least one dimension, are {@link IndexedTensor}s
     * with only indexed bound dimensions, and that the dimensions they have in common match the
     * reduce dimensions: when no reduce dimensions are given, every dimension of both tensors must
     * be a common dimension; otherwise the number of common dimensions must equal the number of
     * reduce dimensions and every common dimension must be among the reduce dimensions.
     *
     * @param tensor the evaluated A argument
     * @param tensor2 the evaluated B argument
     */
    public boolean canOptimize(Tensor tensor, Tensor tensor2) {
        if (tensor.type().dimensions().isEmpty() || tensor2.type().dimensions().isEmpty()) {
            return false;
        }
        if (!(tensor instanceof IndexedTensor) || !tensor.type().hasOnlyIndexedBoundDimensions()) {
            return false;
        }
        if (!(tensor2 instanceof IndexedTensor) || !tensor2.type().hasOnlyIndexedBoundDimensions()) {
            return false;
        }

        TensorType common = dimensionsInCommon((IndexedTensor) tensor, (IndexedTensor) tensor2);
        int commonCount = common.dimensions().size();
        if (dimensions.isEmpty()) {
            return tensor.type().dimensions().size() == commonCount
                    && tensor2.type().dimensions().size() == commonCount;
        }
        if (dimensions.size() != commonCount) {
            return false;
        }
        for (TensorType.Dimension commonDimension : common.dimensions()) {
            if (!dimensions.contains(commonDimension.name())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Dispatches to a specialized product when there is a single reduce dimension which is the
     * last dimension of both tensors and both tensors have 1 or 2 dimensions, and to the general
     * evaluation otherwise.
     *
     * @param a the evaluated A argument
     * @param b the evaluated B argument
     * @param joinedType the type built from the types of a and b
     */
    private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
        TensorType reducedType = Reduce.outputType(joinedType, dimensions);
        if (reduceDimensionIsInnermost(a, b)) {
            int rankA = a.type().dimensions().size();
            int rankB = b.type().dimensions().size();
            if (rankA == 1 && rankB == 1) {
                return vectorVectorProduct(a, b, reducedType);
            }
            if (rankA == 1 && rankB == 2) {
                return vectorMatrixProduct(a, b, reducedType, false);
            }
            if (rankA == 2 && rankB == 1) {
                // The helper wants the vector first; 'swapped' restores the combinator argument order.
                return vectorMatrixProduct(b, a, reducedType, true);
            }
            if (rankA == 2 && rankB == 2) {
                return matrixMatrixProduct(a, b, reducedType);
            }
        }
        return evaluateGeneral(a, b, reducedType);
    }

    /**
     * Combines the cells of two 1-dimensional tensors pairwise, up to the size of the smaller one,
     * and writes the aggregate of the combined values to cell 0 of the result.
     *
     * @throws IllegalArgumentException if either tensor does not have exactly 1 dimension
     */
    private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        if (a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Reduce.ValueAggregator aggregate = Reduce.ValueAggregator.ofType(aggregator);
        for (int ic = 0; ic < commonSize; ic++) {
            aggregate.aggregate(combinator.applyAsDouble(a.get(ic), b.get(ic)));
        }
        builder.cellByDirectIndex(0L, aggregate.aggregatedValue());
        return builder.build();
    }

    /**
     * For each index of the first dimension of the matrix, combines the vector with that row of
     * the matrix, up to the smaller of the vector size and the row length, and writes the
     * aggregate to the result cell with the same index as the row.
     *
     * @param vector the 1-dimensional tensor
     * @param matrix the 2-dimensional tensor, read at direct index {@code row * rowLength + column}
     * @param reducedType the type of the result
     * @param swapped true if the matrix is the A argument, such that the combinator should be
     *                invoked as (matrix value, vector value) rather than (vector value, matrix value)
     * @throws IllegalArgumentException if vector is not 1-dimensional or matrix is not 2-dimensional
     */
    private Tensor vectorMatrixProduct(IndexedTensor vector, IndexedTensor matrix, TensorType reducedType, boolean swapped) {
        if (vector.type().dimensions().size() != 1 || matrix.type().dimensions().size() != 2) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        DimensionSizes vectorSizes = vector.dimensionSizes();
        DimensionSizes matrixSizes = matrix.dimensionSizes();
        // Loop-invariant sizes, computed once instead of in every loop condition.
        long rows = matrixSizes.size(0);
        long rowLength = matrixSizes.size(1);
        long commonSize = Math.min(vectorSizes.size(0), rowLength);
        Reduce.ValueAggregator aggregate = Reduce.ValueAggregator.ofType(aggregator);
        for (int row = 0; row < rows; row++) {
            aggregate.reset();
            for (int ic = 0; ic < commonSize; ic++) {
                double vectorValue = vector.get(ic);
                double matrixValue = matrix.get((row * rowLength) + ic);
                double combined = swapped
                        ? combinator.applyAsDouble(matrixValue, vectorValue)
                        : combinator.applyAsDouble(vectorValue, matrixValue);
                aggregate.aggregate(combined);
            }
            builder.cellByDirectIndex(row, aggregate.aggregatedValue());
        }
        return builder.build();
    }

    /**
     * For each pair of a row of a and a row of b, combines the two rows elementwise, up to the
     * shorter row length, and writes the aggregate to the result cell addressed by the two row indexes.
     *
     * @throws IllegalArgumentException if either tensor does not have exactly 2 dimensions
     */
    private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        if (a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
            throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
        }
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        DimensionSizes sizesA = a.dimensionSizes();
        DimensionSizes sizesB = b.dimensionSizes();

        // The first dimension of each argument is a dimension of the result. The one placed first
        // in the result type gets a stride equal to the size of the other; the other gets stride 1.
        int positionOfA = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get().intValue();
        int positionOfB = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get().intValue();
        long strideA = positionOfA < positionOfB ? sizesB.size(0) : 1L;
        long strideB = positionOfB < positionOfA ? sizesA.size(0) : 1L;

        // Loop-invariant sizes, computed once instead of in every loop condition.
        long rowsA = sizesA.size(0);
        long rowsB = sizesB.size(0);
        long rowLengthA = sizesA.size(1);
        long rowLengthB = sizesB.size(1);
        long commonSize = Math.min(rowLengthA, rowLengthB);

        Reduce.ValueAggregator aggregate = Reduce.ValueAggregator.ofType(aggregator);
        for (int ia = 0; ia < rowsA; ia++) {
            for (int ib = 0; ib < rowsB; ib++) {
                aggregate.reset();
                for (int ic = 0; ic < commonSize; ic++) {
                    double valueA = a.get((ia * rowLengthA) + ic);
                    double valueB = b.get((ib * rowLengthB) + ic);
                    aggregate.aggregate(combinator.applyAsDouble(valueA, valueB));
                }
                builder.cellByDirectIndex((ia * strideA) + (ib * strideB), aggregate.aggregatedValue());
            }
        }
        return builder.build();
    }

    /**
     * Evaluates the reduced join for any number of dimensions: for every index combination of
     * the non-reduced dimensions of a and of b, iterates over all index combinations of the common
     * dimensions, aggregates the combined cell values, and writes the aggregate to the result.
     *
     * <p>NOTE(review): this relies on {@code Reduce.outputType(type, dimensions)} yielding the
     * dimensions of the type which are not reduced, and on {@code Join.mapIndexes(from, to)}
     * yielding, for each dimension of {@code from}, its position in {@code to} - confirm against
     * those classes.
     */
    private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
        IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(reducedType);
        TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
        TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
        TensorType common = dimensionsInCommon(a, b);

        // Strides of each dimension position, used to turn index combinations into direct indexes.
        long[] stridesA = strides(a.type());
        long[] stridesB = strides(b.type());
        long[] stridesResult = strides(reducedType);

        // Mappings between dimension positions of the partial types and the full types.
        int[] onlyAToA = Join.mapIndexes(onlyInA, a.type());
        int[] commonToA = Join.mapIndexes(common, a.type());
        int[] onlyBToB = Join.mapIndexes(onlyInB, b.type());
        int[] commonToB = Join.mapIndexes(common, b.type());
        int[] onlyAToResult = Join.mapIndexes(onlyInA, reducedType);
        int[] onlyBToResult = Join.mapIndexes(onlyInB, reducedType);

        MultiDimensionIterator ic = new MultiDimensionIterator(common);
        Reduce.ValueAggregator aggregate = Reduce.ValueAggregator.ofType(aggregator);
        for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) {
            for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) {
                aggregate.reset();
                for (ic.reset(); ic.hasNext(); ic.next()) {
                    double valueA = a.get(toDirectIndex(ia, ic, stridesA, onlyAToA, commonToA));
                    double valueB = b.get(toDirectIndex(ib, ic, stridesB, onlyBToB, commonToB));
                    aggregate.aggregate(combinator.applyAsDouble(valueA, valueB));
                }
                builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, onlyAToResult, onlyBToResult),
                                          aggregate.aggregatedValue());
            }
        }
        return builder.build();
    }

    /**
     * Returns the direct index given by the current positions of two iterators: the sum, over
     * every position of each iterator, of the current index multiplied by the stride found
     * through that iterator's position mapping.
     *
     * @param first the first iterator
     * @param second the second iterator
     * @param strides the stride of each dimension position of the target type
     * @param firstToTarget for each position of the first iterator, the position in strides to use
     * @param secondToTarget for each position of the second iterator, the position in strides to use
     */
    private long toDirectIndex(MultiDimensionIterator first, MultiDimensionIterator second, long[] strides, int[] firstToTarget, int[] secondToTarget) {
        long directIndex = 0;
        for (int i = 0; i < first.length(); i++) {
            directIndex += strides[firstToTarget[i]] * first.iterator[i];
        }
        for (int i = 0; i < second.length(); i++) {
            directIndex += strides[secondToTarget[i]] * second.iterator[i];
        }
        return directIndex;
    }

    /**
     * Returns the stride of each dimension position of the given type: 1 for the last dimension,
     * and for every other dimension the stride of the next dimension multiplied by the size of
     * the next dimension. Returns an empty array for a type without dimensions.
     *
     * @param type a type where every dimension but the first has a size, otherwise the
     *             unchecked {@code Optional.get()} fails
     */
    private long[] strides(TensorType type) {
        long[] result = new long[type.dimensions().size()];
        if (result.length == 0) {
            return result;
        }
        long stride = 1;
        result[result.length - 1] = stride;
        for (int d = result.length - 2; d >= 0; d--) {
            stride *= type.dimensions().get(d + 1).size().get().longValue();
            result[d] = stride;
        }
        return result;
    }

    /**
     * Returns a type containing the dimensions whose names occur in both tensors, with the
     * combined value type of the two tensor types. When both dimensions of a name have a size,
     * the one with the smaller size is used (b's on a tie); when one lacks a size, that one is used
     * (a's if both do).
     */
    private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
        TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a.type(), b.type()));
        for (TensorType.Dimension dimensionA : a.type().dimensions()) {
            for (TensorType.Dimension dimensionB : b.type().dimensions()) {
                if (!dimensionA.name().equals(dimensionB.name())) {
                    continue;
                }
                if (!dimensionA.size().isPresent()) {
                    builder.set(dimensionA);
                } else if (!dimensionB.size().isPresent()) {
                    builder.set(dimensionB);
                } else {
                    boolean aIsSmaller = dimensionA.size().get().longValue() < dimensionB.size().get().longValue();
                    builder.set(aIsSmaller ? dimensionA : dimensionB);
                }
            }
        }
        return builder.build();
    }

    /**
     * Returns whether there is exactly one dimension to reduce and it is the last dimension of
     * both tensors. When no reduce dimensions are given, the dimensions the tensors have in
     * common are used as the dimensions to reduce.
     *
     * @param tensor the A tensor; cast to {@link IndexedTensor} when no reduce dimensions are given
     * @param tensor2 the B tensor; cast to {@link IndexedTensor} when no reduce dimensions are given
     * @throws IllegalArgumentException if the single reduce dimension is missing in A, or is the
     *         last dimension of A but missing in B
     */
    private boolean reduceDimensionIsInnermost(Tensor tensor, Tensor tensor2) {
        List<String> reducingDimensions = dimensions;
        if (reducingDimensions.isEmpty()) {
            reducingDimensions = dimensionsInCommon((IndexedTensor) tensor, (IndexedTensor) tensor2).dimensions().stream()
                    .map(TensorType.Dimension::name)
                    .toList();
        }
        if (reducingDimensions.size() != 1) {
            return false;
        }
        String dimension = reducingDimensions.get(0);

        int indexInA = tensor.type().indexOfDimension(dimension).orElseThrow(() ->
                new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
        if (indexInA != tensor.type().dimensions().size() - 1) {
            return false; // B is deliberately not consulted in this case
        }
        int indexInB = tensor2.type().indexOfDimension(dimension).orElseThrow(() ->
                new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
        // An index found in the dimension list is at most size - 1, so equality is the exact test.
        return indexInB == tensor2.type().dimensions().size() - 1;
    }

    /**
     * Returns this on the form
     * {@code reduce_join(<A>, <B>, <combinator>, <aggregator><dimension names>)}.
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public String toString(ToStringContext<NAMETYPE> toStringContext) {
        return "reduce_join(" + argumentA.toString(toStringContext) + ", "
                + argumentB.toString(toStringContext) + ", "
                + combinator + ", "
                + aggregator
                + Reduce.commaSeparatedNames(dimensions, toStringContext) + ")";
    }

    /** Returns a hash of the function name, both arguments, the combinator, the aggregator and the dimensions. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public int hashCode() {
        return Objects.hash("reduce_join", argumentA, argumentB, combinator, aggregator, dimensions);
    }
}
