/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.tensor.functions;

import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;

public class Merge<NAMETYPE extends Name>
extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorFunction<NAMETYPE> argumentA;
    private final TensorFunction<NAMETYPE> argumentB;
    private final DoubleBinaryOperator merger;

    public Merge(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator merger) {
        Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
        Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
        Objects.requireNonNull(merger, "The merger function cannot be null");
        this.argumentA = argumentA;
        this.argumentB = argumentB;
        this.merger = merger;
    }

    public static TensorType outputType(TensorType a, TensorType b) {
        return TypeResolver.merge(a, b);
    }

    public DoubleBinaryOperator merger() {
        return this.merger;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(this.argumentA, this.argumentB);
    }

    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("Merge must have 2 arguments, got " + arguments.size());
        }
        return new Merge<NAMETYPE>(arguments.get(0), arguments.get(1), this.merger);
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return new Merge<NAMETYPE>(this.argumentA.toPrimitive(), this.argumentB.toPrimitive(), this.merger);
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return Merge.outputType(this.argumentA.type(context), this.argumentB.type(context));
    }

    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor a = this.argumentA.evaluate(context);
        Tensor b = this.argumentB.evaluate(context);
        TensorType mergedType = Merge.outputType(a.type(), b.type());
        return Merge.evaluate(a, b, mergedType, this.merger);
    }

    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "merge(" + this.argumentA.toString(context) + ", " + this.argumentB.toString(context) + ", " + String.valueOf(this.merger) + ")";
    }

    @Override
    public int hashCode() {
        return Objects.hash("merge", this.argumentA, this.argumentB, this.merger);
    }

    static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        if (Merge.hasSingleIndexedDimension(a) && Merge.hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) {
            return Merge.indexedVectorMerge((IndexedTensor)a, (IndexedTensor)b, mergedType, combinator);
        }
        return Merge.generalMerge(a, b, mergedType, combinator);
    }

    private static boolean hasSingleIndexedDimension(Tensor tensor) {
        return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
    }

    private static Tensor indexedVectorMerge(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
        long aSize = a.dimensionSizes().size(0);
        long bSize = b.dimensionSizes().size(0);
        long mergedSize = Math.max(aSize, bSize);
        long sharedSize = Math.min(aSize, bSize);
        Iterator<Double> aIterator = a.valueIterator();
        Iterator<Double> bIterator = b.valueIterator();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        long i = 0L;
        while (i < sharedSize) {
            builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i++);
        }
        Iterator<Double> largestIterator = aSize > bSize ? aIterator : bIterator;
        long i2 = sharedSize;
        while (i2 < mergedSize) {
            builder.cell(largestIterator.next(), i2++);
        }
        return builder.build();
    }

    private static Tensor generalMerge(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
        Tensor.Builder builder = Tensor.Builder.of(mergedType);
        Merge.addCellsOf(a, b, builder, combinator);
        Merge.addCellsOf(b, a, builder, null);
        return builder.build();
    }

    private static void addCellsOf(Tensor a, Tensor b, Tensor.Builder builder, DoubleBinaryOperator combinator) {
        Iterator<Tensor.Cell> i = a.cellIterator();
        while (i.hasNext()) {
            Map.Entry aCell = i.next();
            TensorAddress key = (TensorAddress)aCell.getKey();
            Double bVal = b.getAsDouble(key);
            if (bVal == null) {
                builder.cell(key, (double)((Double)aCell.getValue()));
                continue;
            }
            if (combinator == null) continue;
            builder.cell(key, combinator.applyAsDouble((Double)aCell.getValue(), bVal));
        }
    }
}

