package com.yahoo.tensor.functions;

import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

/* loaded from: input_file:com/yahoo/tensor/functions/Generate.class */
/**
 * A primitive tensor function which produces a tensor of a fixed type by computing
 * the value of every cell from a generator.
 *
 * <p>The generator is one of:
 * <ul>
 *   <li>a <em>free</em> generator: a plain function from the list of cell indexes to the cell value, or</li>
 *   <li>a <em>bound</em> generator: a scalar function evaluated in a context where the names of the
 *       dimensions of the generated type resolve to the index of the cell currently being generated.</li>
 * </ul>
 *
 * <p>Only types where every dimension is indexed bound are accepted.
 *
 * @param <NAMETYPE> the type of names used by the evaluation and type contexts
 */
public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
    private final TensorType type;

    // Every construction path in this class sets exactly one of these two and leaves the other null.
    private final Function<List<Long>, Double> freeGenerator;
    private final ScalarFunction<NAMETYPE> boundGenerator;

    /**
     * The evaluation context handed to the bound generator: names of dimensions of the generated
     * type resolve to the current cell index, every other lookup is delegated to the wrapped context.
     */
    private class GenerateEvaluationContext implements EvaluationContext<NAMETYPE> {
        private final TensorType type;
        private final EvaluationContext<NAMETYPE> context;

        /** The indexes of the cell currently being generated; assigned by {@link #apply} in bound mode. */
        private IndexedTensor.Indexes indexes;

        GenerateEvaluationContext(TensorType tensorType, EvaluationContext<NAMETYPE> evaluationContext) {
            this.type = tensorType;
            this.context = evaluationContext;
        }

        /**
         * Computes the value of the cell at the given indexes using whichever generator is set.
         *
         * @param indexes the indexes of the cell to generate
         * @return the generated cell value
         */
        double apply(IndexedTensor.Indexes indexes) {
            if (Generate.this.freeGenerator != null) {
                return Generate.this.freeGenerator.apply(indexes.toList());
            }
            // Remember the current position so getTensor can resolve dimension names against it.
            this.indexes = indexes;
            return Generate.this.boundGenerator.apply(this);
        }

        /**
         * Returns the current index along the named dimension if the name is a dimension of the
         * generated type, and otherwise whatever the wrapped context returns for the name.
         */
        @Override // com.yahoo.tensor.evaluation.EvaluationContext
        public Tensor getTensor(String str) {
            return this.type.indexOfDimension(str)
                    .map(dimensionIndex -> Tensor.from(this.indexes.indexesForReading()[dimensionIndex]))
                    .orElseGet(() -> this.context.getTensor(str));
        }

        /** Returns the empty type for names of generated dimensions, otherwise delegates to the wrapped context. */
        @Override // com.yahoo.tensor.evaluation.TypeContext
        public TensorType getType(NAMETYPE nametype) {
            if (this.type.indexOfDimension(nametype.name()).isPresent()) {
                return TensorType.empty;
            }
            return this.context.getType(nametype);
        }

        /** Returns the empty type for names of generated dimensions, otherwise delegates to the wrapped context. */
        @Override // com.yahoo.tensor.evaluation.TypeContext
        public TensorType getType(String str) {
            if (this.type.indexOfDimension(str).isPresent()) {
                return TensorType.empty;
            }
            return this.context.getType(str);
        }

        /** Delegates binding resolution to the wrapped context. */
        @Override // com.yahoo.tensor.evaluation.TypeContext
        public String resolveBinding(String str) {
            return this.context.resolveBinding(str);
        }
    }

    /**
     * The to-string context handed to the bound generator: names of dimensions of the generated
     * type are bound to themselves, every other name is looked up in the wrapped context.
     */
    public class GenerateToStringContext implements ToStringContext<NAMETYPE> {
        private final ToStringContext<NAMETYPE> context;

        public GenerateToStringContext(ToStringContext<NAMETYPE> toStringContext) {
            this.context = toStringContext;
        }

        /** Returns the name itself if it is a dimension of the generated type, otherwise the wrapped context's binding. */
        @Override // com.yahoo.tensor.functions.ToStringContext
        public String getBinding(String str) {
            return Generate.this.type.dimension(str).isPresent() ? str : this.context.getBinding(str);
        }

        /** Returns the wrapped context. */
        @Override // com.yahoo.tensor.functions.ToStringContext
        public ToStringContext<NAMETYPE> parent() {
            return this.context;
        }
    }

    /**
     * Creates a generated tensor from a free generator.
     *
     * @param tensorType the type of the tensor to generate; all dimensions must be indexed bound
     * @param function the function computing a cell value from the list of cell indexes
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the type has a dimension which is not indexed bound
     */
    public Generate(TensorType tensorType, Function<List<Long>, Double> function) {
        this(tensorType, Objects.requireNonNull(function), null);
    }

    /**
     * Creates a generated tensor from a free generator.
     *
     * @param tensorType the type of the tensor to generate; all dimensions must be indexed bound
     * @param function the function computing a cell value from the list of cell indexes
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the type has a dimension which is not indexed bound
     */
    public static <NAMETYPE extends Name> Generate<NAMETYPE> free(TensorType tensorType, Function<List<Long>, Double> function) {
        return new Generate<>(tensorType, Objects.requireNonNull(function), null);
    }

    /**
     * Creates a generated tensor from a bound generator.
     *
     * @param tensorType the type of the tensor to generate; all dimensions must be indexed bound
     * @param scalarFunction the scalar function computing a cell value; during evaluation the names of
     *                       the dimensions of the type resolve to the index of the current cell
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if the type has a dimension which is not indexed bound
     */
    public static <NAMETYPE extends Name> Generate<NAMETYPE> bound(TensorType tensorType, ScalarFunction<NAMETYPE> scalarFunction) {
        return new Generate<>(tensorType, null, Objects.requireNonNull(scalarFunction));
    }

    private Generate(TensorType tensorType, Function<List<Long>, Double> function, ScalarFunction<NAMETYPE> scalarFunction) {
        Objects.requireNonNull(tensorType, "The argument tensor type cannot be null");
        validateType(tensorType);
        this.type = tensorType;
        this.freeGenerator = function;
        this.boundGenerator = scalarFunction;
    }

    /** Throws IllegalArgumentException unless every dimension of the given type is indexed bound. */
    private void validateType(TensorType tensorType) {
        for (TensorType.Dimension dimension : tensorType.dimensions()) {
            if (dimension.type() != TensorType.Dimension.Type.indexedBound) {
                throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
            }
        }
    }

    /**
     * Returns the bound generator as the single argument if it can be represented as a tensor
     * function, and an empty list otherwise (including when this uses a free generator).
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public List<TensorFunction<NAMETYPE>> arguments() {
        if (this.boundGenerator == null) {
            return List.of();
        }
        return this.boundGenerator.asTensorFunction().map(List::of).orElseGet(List::of);
    }

    /**
     * Returns this function with its argument replaced.
     *
     * @param list zero or one argument; with zero arguments this instance is returned unchanged,
     *             with one argument a new bound Generate of the same type is returned
     * @throws IllegalArgumentException if more than one argument is given, or the argument
     *                                  cannot be converted to a scalar function
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> list) {
        if (list.size() > 1) {
            throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + list.size());
        }
        if (list.isEmpty()) {
            return this;
        }
        TensorFunction<NAMETYPE> argument = list.get(0);
        if (argument.asScalarFunction().isEmpty()) {
            throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, but got " + String.valueOf(argument));
        }
        return new Generate<>(this.type, null, argument.asScalarFunction().get());
    }

    /** Returns this, as it is already a primitive function. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
        return this;
    }

    /** Returns the type given at construction; the type context is not consulted. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public TensorType type(TypeContext<NAMETYPE> typeContext) {
        return this.type;
    }

    /**
     * Generates the tensor by visiting every cell position of the type and asking the
     * generator for its value.
     *
     * @param evaluationContext the context used by a bound generator to resolve names which
     *                          are not dimensions of the generated type
     * @return the generated tensor
     */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public Tensor evaluate(EvaluationContext<NAMETYPE> evaluationContext) {
        Tensor.Builder builder = Tensor.Builder.of(this.type);
        IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(this.type));
        GenerateEvaluationContext generateContext = new GenerateEvaluationContext(this.type, evaluationContext);
        // A long counter avoids wrap-around when comparing against a size beyond the int range.
        for (long i = 0; i < indexes.size(); i++) {
            indexes.next();
            builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
        }
        return builder.build();
    }

    /** Collects the size of each dimension of the given type, in dimension order. */
    private DimensionSizes dimensionSizes(TensorType tensorType) {
        List<TensorType.Dimension> dimensions = tensorType.dimensions();
        DimensionSizes.Builder builder = new DimensionSizes.Builder(dimensions.size());
        for (int i = 0; i < builder.dimensions(); i++) {
            // The constructor only accepts indexed bound dimensions, and this relies on those having a size.
            builder.set(i, dimensions.get(i).size().get());
        }
        return builder.build();
    }

    /** Returns the type followed by the generator expression in parentheses. */
    @Override // com.yahoo.tensor.functions.TensorFunction
    public String toString(ToStringContext<NAMETYPE> toStringContext) {
        return this.type + "(" + generatorToString(toStringContext) + ")";
    }

    /** Returns the string form of whichever generator is set. */
    private String generatorToString(ToStringContext<NAMETYPE> toStringContext) {
        if (this.freeGenerator != null) {
            return this.freeGenerator.toString();
        }
        return this.boundGenerator.toString(new GenerateToStringContext(toStringContext));
    }

    @Override // com.yahoo.tensor.functions.TensorFunction
    public int hashCode() {
        return Objects.hash("generate", this.type, this.freeGenerator, this.boundGenerator);
    }
}
