package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Concatenates two tensors along a named dimension.
 *
 * <p>Both arguments must evaluate to indexed tensors. In each argument the concatenation dimension must either be an
 * indexed dimension or be absent; an argument lacking it is first expanded with that dimension at size 1. In the
 * result, cells from the first argument keep their label along the concatenation dimension, while cells from the
 * second argument are shifted by the first argument's size in that dimension. Any other dimension present in both
 * arguments must have the same size in both.
 */
@Beta
public class Concat extends PrimitiveTensorFunction {

    private final TensorFunction argumentA;
    private final TensorFunction argumentB;
    /** Name of the dimension to concatenate along. */
    private final String dimension;

    /**
     * Creates a concatenation of two tensor functions.
     *
     * @param argumentA the function producing the tensor placed first along {@code dimension}
     * @param argumentB the function producing the tensor placed after it along {@code dimension}
     * @param dimension the name of the dimension to concatenate along
     * @throws NullPointerException if any argument is null
     */
    public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.dimension = Objects.requireNonNull(dimension, "The dimension cannot be null");
    }

    /** Returns the two argument functions, first argument first. */
    @Override
    public List<TensorFunction> functionArguments() {
        return ImmutableList.of(argumentA, argumentB);
    }

    /**
     * Returns a new concatenation over the given arguments along the same dimension.
     *
     * @throws IllegalArgumentException if {@code arguments} does not contain exactly two functions
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
        }
        return new Concat(arguments.get(0), arguments.get(1), dimension);
    }

    /** Returns a concatenation of the primitive forms of both arguments. */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
    }

    @Override
    public String toString(ToStringContext context) {
        return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
    }

    /**
     * Evaluates both arguments and concatenates the results along {@code dimension}.
     *
     * @throws IllegalArgumentException if an argument is not an indexed tensor, has the concatenation dimension as a
     *         non-indexed dimension, or differs from the other argument in the size of a shared dimension
     */
    @Override
    public Tensor evaluate(EvaluationContext evaluationContext) {
        // Both arguments are evaluated before either one is validated.
        Tensor evaluatedA = argumentA.evaluate(evaluationContext);
        Tensor evaluatedB = argumentB.evaluate(evaluationContext);
        IndexedTensor a = ensureIndexedDimension(dimension, evaluatedA);
        IndexedTensor b = ensureIndexedDimension(dimension, evaluatedB);

        TensorType concatType = concatType(a.type(), b.type());
        Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize(concatType, a, b, dimension));

        // Loop-invariant lookups, computed once rather than per cell.
        int aLength = a.dimensionSizes().size(requireDimensionIndex(a.type()));
        int concatDimensionIndex = requireDimensionIndex(concatType);
        int[] aToConcat = mapIndexes(a.type(), concatType);
        int[] bToConcat = mapIndexes(b.type(), concatType);

        // b's cells go after all of a's cells along the concatenation dimension; a's cells keep their labels.
        concatenateTo(a, b, aLength, concatType, aToConcat, bToConcat, concatDimensionIndex, builder);
        concatenateTo(b, a, 0, concatType, bToConcat, aToConcat, concatDimensionIndex, builder);
        return builder.build();
    }

    /**
     * Returns the index of the concatenation dimension in the given type.
     *
     * @throws IllegalStateException if the type lacks the dimension (not expected after {@link #ensureIndexedDimension})
     */
    private int requireDimensionIndex(TensorType type) {
        return type.indexOfDimension(dimension)
                   .orElseThrow(() -> new IllegalStateException("Concat dimension '" + dimension +
                                                                "' is missing from " + type));
    }

    /**
     * Adds the cells of {@code source} to {@code builder}, shifted by {@code offset} along the concatenation
     * dimension, once for each subspace of {@code other} whose labels agree with the cell on shared dimensions.
     *
     * @param other the tensor whose subspace addresses supply labels for dimensions {@code source} may lack
     * @param source the tensor whose cell values are written
     * @param offset the amount added to {@code source}'s label along the concatenation dimension
     * @param concatType the result type
     * @param otherToConcat for each dimension of {@code other}, its index in {@code concatType}
     * @param sourceToConcat for each dimension of {@code source}, its index in {@code concatType}
     * @param concatDimensionIndex the index of the concatenation dimension in {@code concatType}
     * @param builder receives the combined cells
     */
    private void concatenateTo(IndexedTensor other, IndexedTensor source, int offset, TensorType concatType,
                               int[] otherToConcat, int[] sourceToConcat, int concatDimensionIndex,
                               Tensor.Builder builder) {
        Set<String> otherDimensions = other.type().dimensionNames().stream()
                                           .filter(name -> ! name.equals(dimension))
                                           .collect(Collectors.toSet());
        int concatRank = concatType.dimensions().size();
        for (Iterator<IndexedTensor.SubspaceIterator> otherSubspaces = other.subspaceIterator(otherDimensions);
             otherSubspaces.hasNext(); ) {
            IndexedTensor.SubspaceIterator otherSubspace = otherSubspaces.next();
            TensorAddress otherAddress = otherSubspace.address();
            for (Iterator<IndexedTensor.SubspaceIterator> sourceSubspaces = source.subspaceIterator(otherDimensions);
                 sourceSubspaces.hasNext(); ) {
                IndexedTensor.SubspaceIterator sourceSubspace = sourceSubspaces.next();
                while (sourceSubspace.hasNext()) {
                    Tensor.Cell cell = sourceSubspace.next();
                    TensorAddress combined = combineAddresses(otherAddress, otherToConcat, cell.getKey(), sourceToConcat,
                                                              concatRank, concatDimensionIndex, offset);
                    if (combined != null) { // null: labels disagree on a shared dimension
                        builder.cell(combined, cell.getValue().doubleValue());
                    }
                }
                // NOTE(review): otherSubspace is not advanced in this loop, so this reset looks redundant; kept as-is.
                otherSubspace.reset();
            }
        }
    }

    /**
     * Returns the given tensor as an indexed tensor that contains the concatenation dimension.
     *
     * @param dimensionName the concatenation dimension
     * @param tensor an evaluated argument
     * @return {@code tensor} itself if it already has {@code dimensionName}; otherwise {@code tensor} multiplied by a
     *         single-cell tensor with value 1.0 in a new indexed dimension {@code dimensionName} of size 1
     * @throws IllegalArgumentException if {@code dimensionName} is a non-indexed dimension of {@code tensor}, if
     *         {@code tensor} has any non-indexed dimension, or if it is not an {@link IndexedTensor}
     */
    private IndexedTensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
        Optional<TensorType.Dimension> concatDimension = tensor.type().dimension(dimensionName);
        if (concatDimension.isPresent() && ! concatDimension.get().isIndexed()) {
            throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
                                               "' requires that dimension to be indexed or absent, but got a tensor with type " +
                                               tensor.type());
        }
        // Checked for every argument, not only when the dimension is absent. Otherwise a mixed tensor that already
        // has the concatenation dimension would fail later on the cast to IndexedTensor with no useful message.
        if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) {
            throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
        }
        if (concatDimension.isPresent()) {
            return asIndexed(tensor);
        }
        Tensor unitAlongDimension = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build())
                                                  .cell(1.0d, 0)
                                                  .build();
        return asIndexed(tensor.multiply(unitAlongDimension));
    }

    /** Casts to {@link IndexedTensor}, reporting other implementations as an illegal argument. */
    private static IndexedTensor asIndexed(Tensor tensor) {
        if ( ! (tensor instanceof IndexedTensor)) {
            throw new IllegalArgumentException("Concat requires an indexed tensor, but got a tensor with type " + tensor.type());
        }
        return (IndexedTensor) tensor;
    }

    /**
     * Returns the result type: the two argument types combined by {@link TensorType.Builder}. When the combined
     * concatenation dimension is bound, it is replaced by an indexed dimension whose size is the sum of the two
     * argument sizes.
     */
    private TensorType concatType(TensorType a, TensorType b) {
        TensorType.Builder builder = new TensorType.Builder(a, b);
        if (builder.getDimension(dimension).get().size().isPresent()) {
            int summedSize = a.dimension(dimension).get().size().get().intValue() +
                             b.dimension(dimension).get().size().get().intValue();
            builder.set(TensorType.Dimension.indexed(dimension, summedSize));
        }
        return builder.build();
    }

    /**
     * Returns the concrete sizes of the result. The concatenation dimension gets the sum of the argument sizes; any
     * other dimension gets the size of whichever argument has it.
     *
     * @throws IllegalArgumentException if both arguments have a non-concatenation dimension with different non-zero sizes
     */
    private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
        DimensionSizes.Builder builder = new DimensionSizes.Builder(concatType.dimensions().size());
        for (int i = 0; i < builder.dimensions(); i++) {
            String name = concatType.dimensions().get(i).name();
            int aSize = sizeOf(a, name);
            int bSize = sizeOf(b, name);
            if (name.equals(concatDimension)) {
                builder.set(i, aSize + bSize);
            } else {
                if (aSize != 0 && bSize != 0 && aSize != bSize) {
                    throw new IllegalArgumentException("Dimension " + name + " must be of the same size when concatenating " +
                                                       a.type() + " and " + b.type() + " along dimension " + concatDimension +
                                                       ", but was " + aSize + " and " + bSize);
                }
                builder.set(i, Math.max(aSize, bSize));
            }
        }
        return builder.build();
    }

    /** Returns the size of the named dimension in the given tensor, or 0 if the tensor lacks that dimension. */
    private static int sizeOf(IndexedTensor tensor, String dimensionName) {
        return tensor.type().indexOfDimension(dimensionName)
                     .map(index -> tensor.dimensionSizes().size(index))
                     .orElse(0);
    }

    /**
     * Returns the address in the result type that combines {@code a} and {@code b}, or null if they have different
     * labels in some shared non-concatenation dimension. Along the concatenation dimension, {@code a}'s shifted label
     * is written first and then overwritten by {@code b}'s label plus {@code offset}.
     */
    private static TensorAddress combineAddresses(TensorAddress a, int[] aToConcat, TensorAddress b, int[] bToConcat,
                                                  int concatRank, int concatDimensionIndex, int offset) {
        int[] labels = new int[concatRank];
        Arrays.fill(labels, -1); // -1 marks a label that is not yet set
        mapContent(a, labels, aToConcat, concatDimensionIndex, offset);
        if ( ! mapContent(b, labels, bToConcat, concatDimensionIndex, offset)) {
            return null;
        }
        return TensorAddress.of(labels);
    }

    /**
     * Returns, for each dimension of {@code from} (in order), the index of the same-named dimension in {@code to}, or
     * -1 if {@code to} lacks it.
     */
    private static int[] mapIndexes(TensorType from, TensorType to) {
        int[] toIndexes = new int[from.dimensions().size()];
        for (int i = 0; i < toIndexes.length; i++) {
            toIndexes[i] = to.indexOfDimension(from.dimensions().get(i).name()).orElse(-1);
        }
        return toIndexes;
    }

    /**
     * Copies the labels of {@code from} into {@code to} at the positions given by {@code indexMap}. The
     * concatenation-dimension label is shifted by {@code offset} and always overwritten. Any other label must equal
     * the value already in that position unless the position is unset (-1).
     *
     * @return false on the first conflicting label, in which case {@code to} may be partially written
     */
    private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimensionIndex, int offset) {
        for (int i = 0; i < from.size(); i++) {
            int toIndex = indexMap[i];
            int label = from.intLabel(i);
            if (toIndex == concatDimensionIndex) {
                to[toIndex] = label + offset;
            } else {
                if (to[toIndex] != -1 && to[toIndex] != label) {
                    return false;
                }
                to[toIndex] = label;
            }
        }
        return true;
    }
}
