package com.yahoo.tensor.functions;

import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Label;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.LabelCache;
import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

/* loaded from: input_file:com/yahoo/tensor/functions/Join.class */
/**
 * The <i>join</i> tensor function: combines the cells of two tensors with a {@link DoubleBinaryOperator}.
 *
 * <p>The output type is computed by {@link TypeResolver#join}. An output cell is produced for each pair of
 * argument cells whose addresses agree on every dimension the two arguments share; its value is
 * {@code combinator.applyAsDouble(valueFromFirstArgument, valueFromSecondArgument)}. Every evaluation strategy
 * below preserves that operand order.
 */
public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    /** Partial address with no dimensions, used when two joined indexed tensors share no dimensions. */
    private static final PartialAddress EMPTY_ADDRESS = new PartialAddress.Builder(0).build();

    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final DoubleBinaryOperator combinator;

    /**
     * Creates a join of two tensor functions.
     *
     * @param argumentA the first argument; its cell values are the left operand of the combinator
     * @param argumentB the second argument; its cell values are the right operand of the combinator
     * @param combinator combines one value from each argument into an output cell value
     * @throws NullPointerException if any argument is null
     */
    public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.combinator = Objects.requireNonNull(combinator, "The combinator function cannot be null");
    }

    /**
     * Returns the type resulting from joining tensors of the two given types.
     *
     * @throws IllegalArgumentException if the types cannot be joined; the resolver's exception is kept as the cause
     */
    public static TensorType outputType(TensorType a, TensorType b) {
        try {
            return TypeResolver.join(a, b);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
        }
    }

    /** Returns the function used to combine pairs of cell values. */
    public DoubleBinaryOperator combinator() {
        return combinator;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argumentA, argumentB);
    }

    /**
     * Returns a join of the two given arguments using this join's combinator.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
        }
        return new Join<>(arguments.get(0), arguments.get(1), combinator);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("join", argumentA, argumentB, combinator);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return outputType(argumentA.type(context), argumentB.type(context));
    }

    /** Evaluates both arguments in the given context and joins the results. */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        return evaluate(a, b, outputType(a.type(), b.type()), combinator);
    }

    /**
     * Joins two tensors into a tensor of {@code joinedType}, choosing the first applicable strategy:
     * <ol>
     *   <li>both are 1-d indexed tensors over the same dimension: element-wise vector join;</li>
     *   <li>both arguments have as many dimensions as the result: cell-by-cell lookup join;</li>
     *   <li>one argument's dimensions contain the other's: subspace join;</li>
     *   <li>otherwise: general join.</li>
     * </ol>
     * NOTE(review): the decompiler reports this method was originally package-private.
     */
    public static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        // The instanceof checks guard the casts below: a tensor with an indexed type that is not an
        // IndexedTensor falls through to the generic strategies instead of failing with ClassCastException.
        if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b)
                && a instanceof IndexedTensor && b instanceof IndexedTensor
                && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
            return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
        }
        int joinedRank = joinedType.dimensions().size();
        if (joinedRank == a.type().dimensions().size() && joinedRank == b.type().dimensions().size()) {
            return singleSpaceJoin(a, b, joinedType, combinator);
        }
        if (a.type().dimensions().containsAll(b.type().dimensions())) {
            // b is the subspace; reversed order so that a's values stay the left operand
            return subspaceJoin(b, a, joinedType, true, combinator);
        }
        if (b.type().dimensions().containsAll(a.type().dimensions())) {
            return subspaceJoin(a, b, joinedType, false, combinator);
        }
        return generalJoin(a, b, joinedType, combinator);
    }

    /** Returns whether the tensor's type has exactly one dimension and that dimension is indexed. */
    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    /**
     * Joins two 1-d indexed tensors over the same dimension element-wise, in value-iteration order.
     * The result's size is the smaller of the two argument sizes.
     */
    private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType,
                                            DoubleBinaryOperator combinator) {
        int joinedLength = (int) Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Iterator<Double> aValues = a.valueIterator();
        Iterator<Double> bValues = b.valueIterator();
        IndexedTensor.Builder builder =
                IndexedTensor.Builder.of(joinedType, new DimensionSizes.Builder(1).set(0, joinedLength).build());
        for (int i = 0; i < joinedLength; i++) {
            builder.cell(combinator.applyAsDouble(aValues.next(), bValues.next()), i);
        }
        return builder.build();
    }

    /**
     * Joins two tensors whose dimension count equals the result's: for each cell of {@code a}, looks up the
     * same address in {@code b} and emits a combined cell only where {@code b} has a value.
     */
    private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> cells = a.cellIterator(); cells.hasNext(); ) {
            Tensor.Cell aCell = cells.next();
            TensorAddress address = aCell.getKey();
            Double bValue = b.getAsDouble(address);
            if (bValue != null) {
                builder.cell(address, combinator.applyAsDouble(aCell.getValue(), bValue));
            }
        }
        return builder.build();
    }

    /**
     * Joins a tensor whose dimensions are a subset of the other's.
     *
     * @param subspace the argument whose dimensions are contained in the other's
     * @param superspace the argument containing all of {@code subspace}'s dimensions
     * @param reversedArgumentOrder if true, the superspace value is the left operand of the combinator
     */
    private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType,
                                       boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) {
            return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType,
                                       reversedArgumentOrder, combinator);
        }
        return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
    }

    /**
     * Subspace join of two indexed tensors. The superspace is iterated once per subspace over the dimensions
     * the subspace lacks, and each such subspace is paired cell-by-cell with the subspace tensor's values.
     * NOTE(review): relies on IndexedTensor.subspaceIterator(dims, sizes) yielding one iterator per label
     * combination of {@code dims}, iterating the remaining dimensions in the same order as valueIterator() —
     * confirm against IndexedTensor.
     */
    private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType,
                                              boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        if (subspace.isEmpty() || superspace.isEmpty()) {
            // Empty input: return an empty tensor without computing joined sizes
            return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build())
                                 .build();
        }
        DimensionSizes joinedSize = joinedSize(joinedType, subspace, superspace);
        IndexedTensor.Builder builder = (IndexedTensor.Builder) Tensor.Builder.of(joinedType, joinedSize);

        Set<String> superOnlyDimensions = new HashSet<>(superspace.type().dimensionNames());
        superOnlyDimensions.removeAll(subspace.type().dimensionNames());

        for (Iterator<IndexedTensor.SubspaceIterator> subspaces = superspace.subspaceIterator(superOnlyDimensions, joinedSize);
             subspaces.hasNext(); ) {
            IndexedTensor.SubspaceIterator subspaceInSuper = subspaces.next();
            joinSubspaces(subspace.valueIterator(), subspace.size(), subspaceInSuper, subspaceInSuper.size(),
                          reversedArgumentOrder, builder, combinator);
        }
        return builder.build();
    }

    /**
     * Pairs up to {@code min(subspaceSize, superspaceSize)} values from the two iterators in order and writes
     * the combined value at each superspace cell's position.
     */
    private static void joinSubspaces(Iterator<Double> subspaceValues, long subspaceSize,
                                      Iterator<Tensor.Cell> superspaceCells, long superspaceSize,
                                      boolean reversedArgumentOrder, IndexedTensor.Builder builder,
                                      DoubleBinaryOperator combinator) {
        int joinedLength = (int) Math.min(subspaceSize, superspaceSize);
        for (int i = 0; i < joinedLength; i++) {
            Tensor.Cell superCell = superspaceCells.next();
            double superValue = superCell.getValue();
            double subValue = subspaceValues.next();
            builder.cell(superCell, reversedArgumentOrder ? combinator.applyAsDouble(superValue, subValue)
                                                          : combinator.applyAsDouble(subValue, superValue));
        }
    }

    /**
     * Returns sizes for each dimension of {@code joinedType}: the smaller of the two sizes for a dimension
     * present in both tensors, otherwise the size from whichever tensor has it. A dimension in neither tensor
     * is left unset.
     */
    private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); i++) {
            String name = joinedType.dimensions().get(i).name();
            Optional<Integer> aIndex = a.type().indexOfDimension(name);
            Optional<Integer> bIndex = b.type().indexOfDimension(name);
            if (aIndex.isPresent() && bIndex.isPresent()) {
                builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
            } else if (aIndex.isPresent()) {
                builder.set(i, a.dimensionSizes().size(aIndex.get()));
            } else if (bIndex.isPresent()) {
                builder.set(i, b.dimensionSizes().size(bIndex.get()));
            }
        }
        return builder.build();
    }

    /**
     * Subspace join for tensors that are not both indexed. For each superspace cell, projects its address onto
     * the subspace's dimensions and emits a combined cell only if the subspace has a value there.
     */
    private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType,
                                              boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
        int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> cells = superspace.cellIterator(); cells.hasNext(); ) {
            Tensor.Cell superCell = cells.next();
            Double subValue = subspace.getAsDouble(superCell.getKey().partialCopy(subspaceIndexes));
            if (subValue == null) continue;
            double superValue = superCell.getValue();
            builder.cell(superCell.getKey(), reversedArgumentOrder ? combinator.applyAsDouble(superValue, subValue)
                                                                   : combinator.applyAsDouble(subValue, superValue));
        }
        return builder.build();
    }

    /**
     * Returns, for each dimension of {@code subspaceType} in order, its index in {@code superspaceType}.
     *
     * @throws java.util.NoSuchElementException if a subspace dimension is missing from the superspace type
     */
    private static int[] subspaceIndexes(TensorType superspaceType, TensorType subspaceType) {
        int[] indexes = new int[subspaceType.dimensions().size()];
        for (int i = 0; i < indexes.length; i++) {
            indexes[i] = superspaceType.indexOfDimension(subspaceType.dimensions().get(i).name()).orElseThrow();
        }
        return indexes;
    }

    /** Joins tensors where neither one's dimensions contain the other's. */
    private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
            return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
        }
        return mappedHashJoin(a, b, joinedType, combinator);
    }

    /** General join of two indexed tensors into a result sized by {@link #joinedSize}. */
    private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType,
                                             DoubleBinaryOperator combinator) {
        DimensionSizes joinedSize = joinedSize(joinedType, a, b);
        Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
        joinTo(a, b, joinedType, joinedSize, mapIndexes(a.type(), joinedType), mapIndexes(b.type(), joinedType),
               builder, combinator);
        return builder.build();
    }

    /**
     * Writes the join of two indexed tensors into {@code builder}. For each cell of {@code a}, iterates the
     * cells of {@code b} whose shared-dimension labels equal that cell's (all of {@code b} when nothing is
     * shared) and emits the combined cell.
     * NOTE(review): the shared labels are read from the subspace iterator's current address, assumed to be the
     * address of the cell just returned — confirm against IndexedTensor.SubspaceIterator.
     */
    private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
                               int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
                               DoubleBinaryOperator combinator) {
        Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
        int sharedDimensionCount = sharedDimensions.size();
        Set<String> aOnlyDimensions = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
        DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
        DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);

        for (Iterator<IndexedTensor.SubspaceIterator> aSubspaces = a.subspaceIterator(aOnlyDimensions, aIterateSize);
             aSubspaces.hasNext(); ) {
            IndexedTensor.SubspaceIterator aSubspace = aSubspaces.next();
            while (aSubspace.hasNext()) {
                Tensor.Cell aCell = aSubspace.next();
                PartialAddress matchingAddress = sharedDimensionCount > 0
                        ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionCount)
                        : EMPTY_ADDRESS;
                for (IndexedTensor.SubspaceIterator bCells = b.cellIterator(matchingAddress, bIterateSize);
                     bCells.hasNext(); ) {
                    Tensor.Cell bCell = bCells.next();
                    builder.cell(joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType),
                                 combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
                }
            }
        }
    }

    /**
     * Builds a partial address holding the labels of {@code address} for those dimensions of {@code type}
     * whose names are in {@code dimensions}.
     */
    private static PartialAddress partialAddress(TensorType type, TensorAddress address, Set<String> dimensions,
                                                 int dimensionCount) {
        PartialAddress.Builder builder = new PartialAddress.Builder(dimensionCount);
        for (int i = 0; i < type.dimensions().size(); i++) {
            String name = type.dimensions().get(i).name();
            if (dimensions.contains(name)) {
                builder.add(name, address.objectLabel(i));
            }
        }
        return builder.build();
    }

    /**
     * Restricts {@code joinedSize} to the dimensions present in {@code type}, in {@code joinedType}'s order.
     * NOTE(review): assumes {@code type}'s dimensions appear in the same relative order in {@code joinedType}.
     */
    private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSize) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        int targetIndex = 0;
        for (int i = 0; i < joinedType.dimensions().size(); i++) {
            if (type.dimensionNames().contains(joinedType.dimensions().get(i).name())) {
                builder.set(targetIndex++, joinedSize.size(i));
            }
        }
        return builder.build();
    }

    /**
     * Nested-loop join over all pairs of cells. Emits a combined cell for each pair whose addresses agree on
     * shared dimensions.
     */
    private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> aCells = a.cellIterator(); aCells.hasNext(); ) {
            Tensor.Cell aCell = aCells.next();
            for (Iterator<Tensor.Cell> bCells = b.cellIterator(); bCells.hasNext(); ) {
                Tensor.Cell bCell = bCells.next();
                TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                if (joinedAddress != null) {
                    builder.cell(joinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
                }
            }
        }
        return builder.build();
    }

    /**
     * Hash join. Indexes the smaller tensor's cells by their labels in the shared dimensions, then probes that
     * table once per cell of the larger tensor. Falls back to {@link #mappedGeneralJoin} when no dimension is
     * shared.
     */
    private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
        TensorType commonDimensionType = commonDimensions(a, b);
        if (commonDimensionType.dimensions().isEmpty()) {
            return mappedGeneralJoin(a, b, joinedType, combinator);
        }

        // Build the hash table from the smaller tensor. Use distinct locals rather than reassigning a/b in place:
        // "a = b; b = a;" without a temporary would leave both pointing at b and silently self-join it.
        boolean swapTensors = a.size() > b.size();
        Tensor small = swapTensors ? b : a;
        Tensor large = swapTensors ? a : b;

        int[] smallIndexesInCommon = mapIndexes(commonDimensionType, small.type());
        int[] largeIndexesInCommon = mapIndexes(commonDimensionType, large.type());
        int[] smallToJoined = mapIndexes(small.type(), joinedType);
        int[] largeToJoined = mapIndexes(large.type(), joinedType);

        var cellsBySharedAddress = new HashMap<TensorAddress, List<Tensor.Cell>>(small.sizeAsInt());
        for (Iterator<Tensor.Cell> smallCells = small.cellIterator(); smallCells.hasNext(); ) {
            Tensor.Cell smallCell = smallCells.next();
            cellsBySharedAddress.computeIfAbsent(smallCell.getKey().partialCopy(smallIndexesInCommon),
                                                 key -> new ArrayList<>())
                                .add(smallCell);
        }

        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        for (Iterator<Tensor.Cell> largeCells = large.cellIterator(); largeCells.hasNext(); ) {
            Tensor.Cell largeCell = largeCells.next();
            List<Tensor.Cell> matches =
                    cellsBySharedAddress.getOrDefault(largeCell.getKey().partialCopy(largeIndexesInCommon), List.of());
            for (Tensor.Cell smallCell : matches) {
                TensorAddress joinedAddress =
                        joinAddresses(smallCell.getKey(), smallToJoined, largeCell.getKey(), largeToJoined, joinedType);
                if (joinedAddress == null) continue;
                // Keep the first argument's value as the left operand regardless of which side was hashed
                double value = swapTensors
                        ? combinator.applyAsDouble(largeCell.getValue(), smallCell.getValue())
                        : combinator.applyAsDouble(smallCell.getValue(), largeCell.getValue());
                builder.cell(joinedAddress, value);
            }
        }
        return builder.build();
    }

    /**
     * Returns, for each dimension of {@code fromType} in order, the index of the same-named dimension in
     * {@code toType} as given by {@code indexOfDimensionAsInt}.
     * NOTE(review): the decompiler reports this method was originally package-private.
     */
    public static int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < toIndexes.length; i++) {
            toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name());
        }
        return toIndexes;
    }

    /**
     * Merges two addresses into an address of {@code joinedType}, placing each label at the index given by the
     * corresponding mapping array.
     *
     * @return the merged address, or null if the two addresses carry different labels for the same joined dimension
     */
    private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                               TensorType joinedType) {
        Label[] labels = new Label[joinedType.dimensions().size()];
        Arrays.fill(labels, LabelCache.INVALID_INDEX_LABEL);
        mapContent(a, labels, aToIndexes);
        return mapContent(b, labels, bToIndexes) ? TensorAddressAny.ofUnsafe(labels) : null;
    }

    /**
     * Copies each label of {@code address} into {@code labels[toIndexes[i]]}.
     *
     * @return false if a target slot already holds a different, non-invalid label
     */
    private static boolean mapContent(TensorAddress address, Label[] labels, int[] toIndexes) {
        int size = address.size();
        for (int i = 0; i < size; i++) {
            int target = toIndexes[i];
            Label label = address.objectLabel(i);
            boolean occupied = !labels[target].isEqualTo(LabelCache.INVALID_INDEX_LABEL);
            if (occupied && !labels[target].isEqualTo(label)) {
                return false;
            }
            labels[target] = label;
        }
        return true;
    }

    /**
     * Returns a type holding the dimensions that are equal (per {@code Dimension.equals}) in both tensors' types,
     * with the combined value type of the two.
     */
    private static TensorType commonDimensions(Tensor a, Tensor b) {
        TensorType aType = a.type();
        TensorType bType = b.type();
        TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
        for (TensorType.Dimension aDimension : aType.dimensions()) {
            for (TensorType.Dimension bDimension : bType.dimensions()) {
                if (aDimension.equals(bDimension)) {
                    builder.set(bDimension);
                }
            }
        }
        return builder.build();
    }
}
