package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;

/**
 * The <i>join</i> tensor function: combines the cells of two argument tensors with a
 * {@link DoubleBinaryOperator} combinator.
 * <p>
 * The result type is built from both argument types with {@link TensorType.Builder}. A pair of cells (one from
 * each argument) contributes to the result only when the cells agree on the labels of the dimensions the two
 * arguments share. In every evaluation strategy below the combinator is invoked as
 * {@code combinator(valueFromA, valueFromB)}.
 */
@Beta
public class Join extends PrimitiveTensorFunction {

    private final TensorFunction argumentA;
    private final TensorFunction argumentB;
    private final DoubleBinaryOperator combinator;

    /**
     * Creates a join of two tensor functions.
     *
     * @param argumentA  the first argument; its cell values are the first combinator operand
     * @param argumentB  the second argument; its cell values are the second combinator operand
     * @param combinator combines a pair of cell values into the value of a result cell
     * @throws NullPointerException if any argument is null
     */
    public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(combinator, "The combinator function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.combinator = combinator;
    }

    /** Returns the first argument of this join. */
    public TensorFunction argumentA() {
        return argumentA;
    }

    /** Returns the second argument of this join. */
    public TensorFunction argumentB() {
        return argumentB;
    }

    /** Returns the function used to combine pairs of cell values. */
    public DoubleBinaryOperator combinator() {
        return combinator;
    }

    /** Returns an immutable list of the two arguments, in order. */
    @Override
    public List<TensorFunction> functionArguments() {
        return ImmutableList.of(argumentA, argumentB);
    }

    /**
     * Returns a join over the given arguments using this join's combinator.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two elements
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
        }
        return new Join(arguments.get(0), arguments.get(1), combinator);
    }

    /** Returns a join over the primitive forms of both arguments, using the same combinator. */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
    }

    /** Returns {@code join(<argumentA>, <argumentB>, <combinator>)}. */
    @Override
    public String toString(ToStringContext context) {
        return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
    }

    /**
     * Evaluates both arguments and joins the resulting tensors.
     * <p>
     * The first applicable strategy is used:
     * <ol>
     *   <li>both arguments have one indexed dimension with the same name: {@code indexedVectorJoin}</li>
     *   <li>the joined type has as many dimensions as each argument: {@code singleSpaceJoin}</li>
     *   <li>the dimensions of one argument contain all dimensions of the other: {@code subspaceJoin}</li>
     *   <li>otherwise: {@code generalJoin}</li>
     * </ol>
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();

        if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b)
            && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
            return indexedVectorJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
        }
        if (joinedType.dimensions().size() == a.type().dimensions().size()
            && joinedType.dimensions().size() == b.type().dimensions().size()) {
            return singleSpaceJoin(a, b, joinedType);
        }
        if (a.type().dimensions().containsAll(b.type().dimensions())) {
            // b is the subspace; pass 'true' so the combinator still receives (aValue, bValue)
            return subspaceJoin(b, a, joinedType, true);
        }
        if (b.type().dimensions().containsAll(a.type().dimensions())) {
            return subspaceJoin(a, b, joinedType, false);
        }
        return generalJoin(a, b, joinedType);
    }

    /** Returns whether the given tensor has exactly one dimension and that dimension is indexed. */
    private boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    /**
     * Joins two indexed vectors over the same dimension position by position. The result length is the
     * shorter of the two vector lengths.
     */
    private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
        int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
        Iterator<Double> aValues = a.valueIterator();
        Iterator<Double> bValues = b.valueIterator();
        IndexedTensor.Builder builder =
                IndexedTensor.Builder.of(joinedType, new DimensionSizes.Builder(1).set(0, joinedLength).build());
        for (int i = 0; i < joinedLength; i++) {
            builder.cell(combinator.applyAsDouble(aValues.next(), bValues.next()), i);
        }
        return builder.build();
    }

    /**
     * Joins two tensors whose joined type has the same number of dimensions as each argument, by looking up
     * each cell address of {@code a} in {@code b}.
     */
    private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) {
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> aCells = a.cellIterator();
        while (aCells.hasNext()) {
            Tensor.Cell aCell = aCells.next();
            double bValue = b.get(aCell.getKey());
            // NaN is treated as "no matching cell in b" (presumably what get() returns for absent addresses);
            // note this also skips cells of b that actually hold NaN.
            if (!Double.isNaN(bValue)) {
                builder.cell(aCell.getKey(), combinator.applyAsDouble(aCell.getValue(), bValue));
            }
        }
        return builder.build();
    }

    /**
     * Joins a tensor with another whose dimensions include all of its own.
     *
     * @param subspace              the argument whose dimensions are a subset of the other's
     * @param superspace            the argument containing all dimensions of {@code subspace}
     * @param reversedArgumentOrder true if {@code subspace} is argument b, so that combinator operands are
     *                              swapped back to (aValue, bValue)
     */
    private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
        if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) {
            return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
        }
        return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
    }

    /**
     * Subspace join for two indexed tensors: streams the subspace values alongside each subspace of the
     * superspace, truncating to the shorter of the two.
     */
    private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType,
                                       boolean reversedArgumentOrder) {
        // If either argument has no cells there are no pairs to combine.
        if (subspace.size() == 0 || superspace.size() == 0) {
            return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
        }

        DimensionSizes joinedSize = joinedSize(joinedType, subspace, superspace);
        IndexedTensor.Builder builder = (IndexedTensor.Builder) Tensor.Builder.of(joinedType, joinedSize);

        // The superspace is iterated over the dimensions the subspace argument does not have.
        Set<String> superspaceOnlyDimensions = new HashSet<>(superspace.type().dimensionNames());
        superspaceOnlyDimensions.removeAll(subspace.type().dimensionNames());
        Iterator<IndexedTensor.SubspaceIterator> superspaceSubspaces =
                superspace.subspaceIterator(superspaceOnlyDimensions, joinedSize);
        while (superspaceSubspaces.hasNext()) {
            IndexedTensor.SubspaceIterator superspaceCells = superspaceSubspaces.next();
            joinSubspaces(subspace.valueIterator(), subspace.size(), superspaceCells, superspaceCells.size(),
                          reversedArgumentOrder, builder);
        }
        return builder.build();
    }

    /**
     * Combines the first {@code min(subspaceSize, superspaceSize)} elements of the two iterators pairwise,
     * writing each result at the address of the superspace cell.
     */
    private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
                               Iterator<Tensor.Cell> superspace, int superspaceSize,
                               boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
        int joinedLength = Math.min(subspaceSize, superspaceSize);
        if (reversedArgumentOrder) {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(supercell.getValue(), subspace.next()));
            }
        } else {
            for (int i = 0; i < joinedLength; i++) {
                Tensor.Cell supercell = superspace.next();
                builder.cell(supercell, combinator.applyAsDouble(subspace.next(), supercell.getValue()));
            }
        }
    }

    /**
     * Returns the size of each joined dimension: the smaller of the two argument sizes for a shared dimension,
     * otherwise the size in whichever argument has the dimension.
     */
    private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); i++) {
            String dimensionName = joinedType.dimensions().get(i).name();
            Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName);
            Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName);
            if (aIndex.isPresent() && bIndex.isPresent()) {
                builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
            } else if (aIndex.isPresent()) {
                builder.set(i, a.dimensionSizes().size(aIndex.get()));
            } else if (bIndex.isPresent()) {
                builder.set(i, b.dimensionSizes().size(bIndex.get()));
            }
        }
        return builder.build();
    }

    /**
     * Subspace join for arbitrary tensors: for each superspace cell, looks up the subspace cell at the
     * superspace address projected onto the subspace's dimensions.
     */
    private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType,
                                       boolean reversedArgumentOrder) {
        int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> superspaceCells = superspace.cellIterator();
        while (superspaceCells.hasNext()) {
            Tensor.Cell supercell = superspaceCells.next();
            double subspaceValue = subspace.get(mapAddressToSubspace(supercell.getKey(), subspaceIndexes));
            // NaN is treated as "no matching subspace cell".
            if (!Double.isNaN(subspaceValue)) {
                builder.cell(supercell.getKey(), reversedArgumentOrder
                        ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
                        : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
            }
        }
        return builder.build();
    }

    /**
     * Returns, for each dimension of {@code subtype} in order, its index in {@code supertype}. Every subtype
     * dimension is expected to exist in the supertype (callers only use this for subspace joins).
     */
    private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
        int[] indexes = new int[subtype.dimensions().size()];
        for (int i = 0; i < subtype.dimensions().size(); i++) {
            indexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
        }
        return indexes;
    }

    /** Projects a superspace address onto the subspace by picking the labels at the given indexes. */
    private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
        String[] subspaceLabels = new String[subspaceIndexes.length];
        for (int i = 0; i < subspaceIndexes.length; i++) {
            subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
        }
        return TensorAddress.of(subspaceLabels);
    }

    /** Join where neither argument's dimensions contain the other's. */
    private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
        if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
            return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
        }
        return mappedHashJoin(a, b, joinedType);
    }

    /** General join of two indexed tensors. */
    private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
        DimensionSizes joinedSize = joinedSize(joinedType, a, b);
        Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        // NOTE(review): both directions are joined, and both passes compute combinator(aValue, bValue) for a
        // pair, so pairs reached from both sides write the same cell twice — confirm the second pass is needed.
        joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
        joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
        return builder.build();
    }

    /**
     * Iterates {@code a} by subspaces over the dimensions {@code b} lacks and, for each cell, combines it with
     * the cells of {@code b} that match on the shared dimensions, writing results into {@code builder}.
     *
     * @param reversedOrder true if {@code a} here is really argument b, so combinator operands are swapped back
     */
    private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
                        int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) {
        Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
        Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
        DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
        DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);

        Iterator<IndexedTensor.SubspaceIterator> aSubspaces = a.subspaceIterator(dimensionsOnlyInA, aIterateSize);
        while (aSubspaces.hasNext()) {
            IndexedTensor.SubspaceIterator aSubspace = aSubspaces.next();
            while (aSubspace.hasNext()) {
                Tensor.Cell aCell = aSubspace.next();
                PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions);
                IndexedTensor.SubspaceIterator bCells = b.cellIterator(matchingBCells, bIterateSize);
                while (bCells.hasNext()) {
                    Tensor.Cell bCell = bCells.next();
                    TensorAddress joinedAddress =
                            joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                    double joinedValue = reversedOrder
                            ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue())
                            : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
                    builder.cell(joinedAddress, joinedValue);
                }
            }
        }
    }

    /** Returns a partial address holding the labels of {@code address} for the dimensions in {@code dimensions}. */
    private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> dimensions) {
        PartialAddress.Builder builder = new PartialAddress.Builder(dimensions.size());
        for (int i = 0; i < addressType.dimensions().size(); i++) {
            String dimensionName = addressType.dimensions().get(i).name();
            if (dimensions.contains(dimensionName)) {
                builder.add(dimensionName, address.intLabel(i));
            }
        }
        return builder.build();
    }

    /** Returns the joined sizes restricted to the dimensions of {@code type}, in joined-type order. */
    private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSize) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
        int outputIndex = 0;
        for (int i = 0; i < joinedType.dimensions().size(); i++) {
            if (type.dimensionNames().contains(joinedType.dimensions().get(i).name())) {
                builder.set(outputIndex++, joinedSize.size(i));
            }
        }
        return builder.build();
    }

    /** Nested-loop join over all cell pairs, keeping pairs whose labels agree on shared dimensions. */
    private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
        int[] aToIndexes = mapIndexes(a.type(), joinedType);
        int[] bToIndexes = mapIndexes(b.type(), joinedType);
        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> aCells = a.cellIterator();
        while (aCells.hasNext()) {
            Tensor.Cell aCell = aCells.next();
            Iterator<Tensor.Cell> bCells = b.cellIterator();
            while (bCells.hasNext()) {
                Tensor.Cell bCell = bCells.next();
                TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
                if (joinedAddress != null) {
                    builder.cell(joinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
                }
            }
        }
        return builder.build();
    }

    /**
     * Hash join: cells of the smaller argument are grouped by their labels on the common dimensions, and the
     * larger argument probes those groups. Falls back to {@link #mappedGeneralJoin} when there are no common
     * dimensions.
     */
    private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
        TensorType commonDimensionType = commonDimensions(a, b);
        if (commonDimensionType.dimensions().isEmpty()) {
            return mappedGeneralJoin(a, b, joinedType);
        }

        // Build from the smaller argument. Select sides without overwriting either argument (the previous
        // in-place swap lost 'a' and joined b with itself). 'swapped' means the build side is b.
        boolean swapped = a.size() > b.size();
        Tensor buildSide = swapped ? b : a;
        Tensor probeSide = swapped ? a : b;

        int[] buildCommonIndexes = mapIndexes(commonDimensionType, buildSide.type());
        int[] probeCommonIndexes = mapIndexes(commonDimensionType, probeSide.type());
        int[] buildToJoinedIndexes = mapIndexes(buildSide.type(), joinedType);
        int[] probeToJoinedIndexes = mapIndexes(probeSide.type(), joinedType);

        Map<TensorAddress, List<Tensor.Cell>> buildCellsByCommonAddress = new HashMap<>();
        Iterator<Tensor.Cell> buildCells = buildSide.cellIterator();
        while (buildCells.hasNext()) {
            Tensor.Cell buildCell = buildCells.next();
            buildCellsByCommonAddress
                    .computeIfAbsent(partialCommonAddress(buildCell, buildCommonIndexes), key -> new ArrayList<>())
                    .add(buildCell);
        }

        Tensor.Builder builder = Tensor.Builder.of(joinedType);
        Iterator<Tensor.Cell> probeCells = probeSide.cellIterator();
        while (probeCells.hasNext()) {
            Tensor.Cell probeCell = probeCells.next();
            List<Tensor.Cell> matches = buildCellsByCommonAddress.getOrDefault(
                    partialCommonAddress(probeCell, probeCommonIndexes), Collections.emptyList());
            for (Tensor.Cell buildCell : matches) {
                TensorAddress joinedAddress = joinAddresses(buildCell.getKey(), buildToJoinedIndexes,
                                                            probeCell.getKey(), probeToJoinedIndexes, joinedType);
                if (joinedAddress == null) {
                    continue; // labels disagree on a shared dimension outside the common type
                }
                // Keep the combinator's operand order as (aValue, bValue) regardless of which side is which.
                double joinedValue = swapped
                        ? combinator.applyAsDouble(probeCell.getValue(), buildCell.getValue())
                        : combinator.applyAsDouble(buildCell.getValue(), probeCell.getValue());
                builder.cell(joinedAddress, joinedValue);
            }
        }
        return builder.build();
    }

    /**
     * Returns, for each dimension of {@code fromType} in order, its index in {@code toType}, or -1 if
     * {@code toType} lacks that dimension.
     */
    private int[] mapIndexes(TensorType fromType, TensorType toType) {
        int[] toIndexes = new int[fromType.dimensions().size()];
        for (int i = 0; i < fromType.dimensions().size(); i++) {
            toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
        }
        return toIndexes;
    }

    /**
     * Merges two addresses into an address of {@code joinedType}.
     *
     * @return the joined address, or null if the two addresses have different labels for a shared dimension
     */
    private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
                                        TensorType joinedType) {
        String[] joinedLabels = new String[joinedType.dimensions().size()];
        mapContent(a, joinedLabels, aToIndexes);
        if (mapContent(b, joinedLabels, bToIndexes)) {
            return TensorAddress.of(joinedLabels);
        }
        return null;
    }

    /**
     * Writes the labels of {@code from} into {@code to} at the positions given by {@code indexMap}.
     *
     * @return false (leaving {@code to} partially written) as soon as a position already holds a different
     *         label, true otherwise
     */
    private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
        for (int i = 0; i < from.size(); i++) {
            int toIndex = indexMap[i];
            if (to[toIndex] != null && !to[toIndex].equals(from.label(i))) {
                return false;
            }
            to[toIndex] = from.label(i);
        }
        return true;
    }

    /** Returns a type containing the dimensions that are equal (per {@code Dimension.equals}) in both tensors. */
    private TensorType commonDimensions(Tensor a, Tensor b) {
        TensorType.Builder builder = new TensorType.Builder();
        TensorType aType = a.type();
        TensorType bType = b.type();
        for (int i = 0; i < aType.dimensions().size(); i++) {
            TensorType.Dimension aDimension = aType.dimensions().get(i);
            for (int j = 0; j < bType.dimensions().size(); j++) {
                TensorType.Dimension bDimension = bType.dimensions().get(j);
                if (aDimension.equals(bDimension)) {
                    builder.set(bDimension);
                }
            }
        }
        return builder.build();
    }

    /** Returns the address formed by the labels of the cell's address at the given indexes. */
    private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexes) {
        TensorAddress address = cell.getKey();
        String[] labels = new String[indexes.length];
        for (int i = 0; i < labels.length; i++) {
            labels[i] = address.label(indexes[i]);
        }
        return TensorAddress.of(labels);
    }

}
