package com.yahoo.tensor.functions;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;

/**
 * The <i>merge</i> tensor function: produces a tensor containing the union of the cells of two argument tensors.
 * Where both arguments have a cell at the same address, the two values are combined with the merger function,
 * called as {@code merger.applyAsDouble(valueFromA, valueFromB)}. Cells present in only one of the arguments
 * are copied to the result unchanged. The result type is computed by {@link TypeResolver#merge}.
 *
 * @param <NAMETYPE> the name type used by the type and evaluation contexts this function is resolved in
 */
public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final DoubleBinaryOperator merger;

    /**
     * Creates a merge of the tensors produced by two functions.
     *
     * @param argumentA the function producing the first tensor
     * @param argumentB the function producing the second tensor
     * @param merger combines the values of cells present in both tensors, called as (valueA, valueB)
     * @throws NullPointerException if any argument is null
     */
    public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) {
        this.argumentA = Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        this.argumentB = Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        this.merger = Objects.requireNonNull(merger, "The merger function cannot be null");
    }

    /** Returns the type of the tensor produced by merging tensors of the two given types. */
    public static TensorType outputType(TensorType a, TensorType b) {
        return TypeResolver.merge(a, b);
    }

    /** Returns the function used to combine the values of cells present in both arguments. */
    public DoubleBinaryOperator merger() {
        return merger;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(argumentA, argumentB);
    }

    /**
     * Returns a merge with the same merger function over the given two arguments.
     *
     * @throws IllegalArgumentException if the list does not contain exactly two arguments
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size());
        }
        return new Merge<>(arguments.get(0), arguments.get(1), merger);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Merge<>(argumentA.toPrimitive(), argumentB.toPrimitive(), merger);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return outputType(argumentA.type(context), argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = argumentA.evaluate(context);
        Tensor b = argumentB.evaluate(context);
        return evaluate(a, b, outputType(a.type(), b.type()), merger);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("merge", argumentA, argumentB, merger);
    }

    /**
     * Merges two already evaluated tensors into a tensor of the given type.
     * Uses a positional fast path when both tensors are indexed vectors over the same dimension,
     * and a cell-by-cell address lookup otherwise.
     */
    static Tensor evaluate(Tensor a, Tensor b, TensorType resultType, DoubleBinaryOperator merger) {
        if (isIndexedVectorPair(a, b)) {
            return indexedVectorMerge((IndexedTensor) a, (IndexedTensor) b, resultType, merger);
        }
        return generalMerge(a, b, resultType, merger);
    }

    /**
     * Returns whether both tensors are {@code IndexedTensor} instances whose types have a single indexed
     * dimension with the same name. The instanceof checks guard the casts in {@link #evaluate}:
     * without them, a non-IndexedTensor with such a type would fail with a ClassCastException.
     */
    private static boolean isIndexedVectorPair(Tensor a, Tensor b) {
        return a instanceof IndexedTensor
               && b instanceof IndexedTensor
               && hasSingleIndexedDimension(a)
               && hasSingleIndexedDimension(b)
               && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name());
    }

    /** Returns whether the tensor's type has exactly one dimension, and that dimension is indexed. */
    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    /**
     * Merges two one-dimensional indexed tensors positionally.
     * Indices present in both vectors are combined with the merger.
     * The remaining tail of the longer vector is copied unchanged.
     */
    private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType resultType,
                                             DoubleBinaryOperator merger) {
        long aSize = a.dimensionSizes().size(0);
        long bSize = b.dimensionSizes().size(0);
        long commonSize = Math.min(aSize, bSize);
        long resultSize = Math.max(aSize, bSize);
        Iterator<Double> aValues = a.valueIterator();
        Iterator<Double> bValues = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(resultType);
        for (long i = 0; i < commonSize; i++) {
            builder.cell(merger.applyAsDouble(aValues.next().doubleValue(), bValues.next().doubleValue()), i);
        }
        // Both iterators have advanced commonSize steps, so the longer one is positioned at its tail.
        Iterator<Double> tailValues = aSize > bSize ? aValues : bValues;
        for (long i = commonSize; i < resultSize; i++) {
            builder.cell(tailValues.next().doubleValue(), i);
        }
        return builder.build();
    }

    /** Merges two tensors of any type by looking up each cell's address in the other tensor. */
    private static Tensor generalMerge(Tensor a, Tensor b, TensorType resultType, DoubleBinaryOperator merger) {
        Tensor.Builder builder = Tensor.Builder.of(resultType);
        addCellsOf(a, b, builder, merger);
        // Cells shared with a were already combined in the pass above, so only b's unique cells are added here.
        addCellsOf(b, a, builder, null);
        return builder.build();
    }

    /**
     * Adds the cells of {@code source} to the builder.
     * A cell whose address is absent from {@code other} is copied unchanged. A cell whose address is present
     * in {@code other} is combined as {@code merger(sourceValue, otherValue)}, or skipped when
     * {@code merger} is null.
     */
    private static void addCellsOf(Tensor source, Tensor other, Tensor.Builder builder, DoubleBinaryOperator merger) {
        for (Iterator<Tensor.Cell> cells = source.cellIterator(); cells.hasNext(); ) {
            Tensor.Cell cell = cells.next();
            TensorAddress address = cell.getKey();
            if ( ! other.has(address)) {
                builder.cell(address, cell.getValue().doubleValue());
            } else if (merger != null) {
                builder.cell(address, merger.applyAsDouble(cell.getValue().doubleValue(), other.get(address)));
            }
        }
    }

}
