package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.functions.Reduce;
import java.util.List;
import java.util.Objects;

/**
 * Composite tensor function taking two tensor arguments and a dimension name.
 *
 * <p>Its primitive form joins the two arguments with the multiply scalar function and
 * reduces the joined result with the sum aggregator along the given dimension.
 */
@Beta
public class Matmul extends CompositeTensorFunction {
    private final TensorFunction argument1;
    private final TensorFunction argument2;
    private final String dimension;

    /**
     * Creates a matmul function over two arguments.
     *
     * @param tensorFunction the first argument; must not be null
     * @param tensorFunction2 the second argument; must not be null
     * @param str the name of the dimension to sum over; must not be null
     * @throws NullPointerException if any parameter is null
     */
    public Matmul(TensorFunction tensorFunction, TensorFunction tensorFunction2, String str) {
        // Fail fast: a null argument would otherwise only surface later, when
        // functionArguments(), toPrimitive() or toString() dereference the fields.
        this.argument1 = Objects.requireNonNull(tensorFunction, "The first argument tensor cannot be null");
        this.argument2 = Objects.requireNonNull(tensorFunction2, "The second argument tensor cannot be null");
        this.dimension = Objects.requireNonNull(str, "The dimension cannot be null");
    }

    /**
     * Returns the two arguments of this function, in order.
     *
     * @return an immutable list holding the first and the second argument
     */
    @Override
    public List<TensorFunction> functionArguments() {
        return ImmutableList.of(this.argument1, this.argument2);
    }

    /**
     * Returns a new matmul over the given arguments, keeping this function's dimension.
     *
     * @param list the replacement arguments; must contain exactly two elements
     * @return a new {@code Matmul} using {@code list.get(0)} and {@code list.get(1)}
     * @throws IllegalArgumentException if {@code list} does not have exactly two elements
     */
    @Override
    public TensorFunction replaceArguments(List<TensorFunction> list) {
        if (list.size() != 2) {
            throw new IllegalArgumentException("Matmul must have 2 arguments, got " + list.size());
        }
        return new Matmul(list.get(0), list.get(1), this.dimension);
    }

    /**
     * Expresses this function with primitive functions only.
     *
     * @return a sum-{@code Reduce} along the dimension of a multiply-{@code Join}
     *         of the two arguments' primitive forms
     */
    @Override
    public PrimitiveTensorFunction toPrimitive() {
        Join product = new Join(this.argument1.toPrimitive(), this.argument2.toPrimitive(), ScalarFunctions.multiply());
        return new Reduce(product, Reduce.Aggregator.sum, this.dimension);
    }

    /**
     * Renders this function as {@code matmul(<argument1>, <argument2>, <dimension>)}.
     *
     * @param toStringContext the context passed on to the arguments' own {@code toString}
     * @return the textual form of this function
     */
    @Override
    public String toString(ToStringContext toStringContext) {
        return "matmul(" + this.argument1.toString(toStringContext) + ", " + this.argument2.toString(toStringContext) + ", " + this.dimension + ")";
    }
}
