package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Reduce;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:com/yahoo/tensor/functions/EuclideanDistance.class */
/**
 * Composite tensor function printed as {@code euclidean_distance(a, b, dim)}.
 *
 * <p>It is not evaluated directly: {@link #toPrimitive()} expands it into
 * a join of the two arguments with {@code ScalarFunctions.subtract()}, a map with
 * {@code ScalarFunctions.square()}, a sum-reduce over the configured dimension and
 * finally a map with {@code ScalarFunctions.sqrt()}.
 *
 * @param <NAMETYPE> the name type used by the evaluation and type contexts
 */
public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> arg1;
    private final TensorFunction<NAMETYPE> arg2;
    private final String dimension;

    /**
     * Creates a distance function between two tensor-valued functions.
     *
     * @param argument1 the first operand
     * @param argument2 the second operand
     * @param dimension the name of the dimension that is reduced away when computing the distance
     */
    public EuclideanDistance(TensorFunction<NAMETYPE> argument1,
                             TensorFunction<NAMETYPE> argument2,
                             String dimension) {
        this.arg1 = argument1;
        this.arg2 = argument2;
        this.dimension = dimension;
    }

    /** Returns the two operands, in the order they were given to the constructor. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() {
        return List.of(arg1, arg2);
    }

    /**
     * Returns a copy of this function with its operands replaced, keeping the same dimension.
     *
     * @param arguments the replacement operands; must contain exactly two elements
     * @return a new function over the given operands
     * @throws IllegalArgumentException if the list does not contain exactly two elements
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 2) {
            throw new IllegalArgumentException("EuclideanDistance must have 2 arguments, got " + arguments.size());
        }
        // Diamond keeps the NAMETYPE parameter instead of falling back to a raw type.
        return new EuclideanDistance<>(arguments.get(0), arguments.get(1), dimension);
    }

    /**
     * Returns the type this function produces, after validating the operand types.
     *
     * <p>Both operand types must contain the (binding-resolved) dimension, it must be of
     * type {@code indexedBound} in both, and its size must be the same in both.
     *
     * @param context the context used to type the operands and to resolve the dimension binding
     * @return the type of the primitive expansion of this function
     * @throws IllegalArgumentException if the operand types do not satisfy the requirements above
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType firstType = arg1.toPrimitive().type(context);
        TensorType secondType = arg2.toPrimitive().type(context);
        String resolvedDimension = context.resolveBinding(dimension);
        Optional<TensorType.Dimension> firstDimension = firstType.dimension(resolvedDimension);
        Optional<TensorType.Dimension> secondDimension = secondType.dimension(resolvedDimension);

        boolean compatible = firstDimension.isPresent()
                && secondDimension.isPresent()
                && firstDimension.get().type() == TensorType.Dimension.Type.indexedBound
                && secondDimension.get().type() == TensorType.Dimension.Type.indexedBound
                && firstDimension.get().size().equals(secondDimension.get().size());
        if (!compatible) {
            throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + resolvedDimension + "' dimension with same size, but input types were " + firstType + " and " + secondType);
        }
        // The result type is whatever the primitive expansion of this function produces.
        return toPrimitive().type(context);
    }

    /** Evaluates this function by evaluating its primitive expansion in the given context. */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        return toPrimitive().evaluate(context);
    }

    /**
     * Expands this function into primitive tensor functions:
     * join(subtract), then map(square), then reduce(sum) over the dimension, then map(sqrt).
     */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        TensorFunction<NAMETYPE> difference =
                new Join<>(arg1.toPrimitive(), arg2.toPrimitive(), ScalarFunctions.subtract());
        TensorFunction<NAMETYPE> squared = new Map<>(difference, ScalarFunctions.square());
        TensorFunction<NAMETYPE> summed = new Reduce<>(squared, Reduce.Aggregator.sum, dimension);
        return new Map<>(summed, ScalarFunctions.sqrt());
    }

    /** Renders this function as {@code euclidean_distance(arg1, arg2, dimension)} with the dimension binding resolved. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "euclidean_distance("
                + arg1.toString(context) + ", "
                + arg2.toString(context) + ", "
                + context.resolveBinding(dimension) + ")";
    }

    /** Hashes the function name together with both operands and the dimension. */
    @Override
    public int hashCode() {
        return Objects.hash("euclidean_distance", arg1, arg2, dimension);
    }
}
