package ml.regression;

import datasets.DenseMatrixSet;
import datasets.VectorDouble;
import datastructs.IVector;
import maths.functions.IVectorRealFunction;
import optimization.ISupervisedOptimizer;

/* loaded from: input_file:ml/regression/RegressorBase.class */
/**
 * Common base for regression models built around a single hypothesis function.
 *
 * <p>The hypothesis supplied at construction is handed to an optimizer by {@link #train},
 * evaluated row by row by the {@code predict} overloads, and used by {@link #getErrors} to
 * compute residuals.
 *
 * @param <DataSetType> feature matrix type; each row is evaluated as one example
 * @param <HypothesisType> function evaluated on a single row to produce a prediction
 */
public class RegressorBase<DataSetType extends DenseMatrixSet<Double>,
        HypothesisType extends IVectorRealFunction<IVector<Double>>> {

    /** The hypothesis passed to the optimizer in {@link #train} and evaluated for predictions. */
    protected HypothesisType hypothesisType;

    /**
     * Creates a regressor around the given hypothesis.
     *
     * <p>Decompiler note: this constructor was originally declared {@code protected}.
     *
     * @param hypothesis the hypothesis to train and evaluate; stored as-is (not null-checked)
     */
    public RegressorBase(HypothesisType hypothesis) {
        this.hypothesisType = hypothesis;
    }

    /**
     * Fits the hypothesis by delegating to {@code optimizer.optimize(dataSet, labels, hypothesis)}.
     *
     * @param dataSet feature matrix, one example per row
     * @param labels vector passed through to the optimizer (presumably the targets — confirm)
     * @param optimizer the optimizer to run against the hypothesis
     * @param <OutputType> type the caller expects back; the cast is unchecked, so a mismatch
     *     surfaces as a {@code ClassCastException} where the result is used
     * @return whatever the optimizer returns, cast to {@code OutputType}
     */
    @SuppressWarnings("unchecked")
    public <OutputType> OutputType train(DataSetType dataSet, VectorDouble labels, ISupervisedOptimizer optimizer) {
        final Object optimizerResult = optimizer.optimize(dataSet, labels, hypothesisType);
        return (OutputType) optimizerResult;
    }

    /**
     * Evaluates the hypothesis on a single input vector.
     *
     * @param point the input to evaluate
     * @return the hypothesis value, which is expected to be a {@code Double}
     */
    public double predict(VectorDouble point) {
        final Double estimate = (Double) hypothesisType.evaluate(point);
        return estimate.doubleValue();
    }

    /**
     * Evaluates the hypothesis on every row of {@code dataSet}.
     *
     * @param dataSet rows to evaluate
     * @return a vector whose i-th entry is the hypothesis value for row i
     */
    public VectorDouble predict(DataSetType dataSet) {
        final VectorDouble predictions = new VectorDouble(dataSet.m(), 0.0d);
        int row = 0;
        while (row < dataSet.m()) {
            final Double estimate = (Double) hypothesisType.evaluate(dataSet.getRow(row));
            predictions.set(row, estimate);
            row++;
        }
        return predictions;
    }

    /**
     * Computes residuals {@code targets[i] - h(row i)} for every row of {@code dataSet}.
     *
     * @param dataSet rows to evaluate
     * @param targets observed values, one per row
     * @return a vector of residuals, same length as {@code targets}
     * @throws IllegalArgumentException if {@code targets.size()} differs from {@code dataSet.m()}
     */
    public VectorDouble getErrors(DataSetType dataSet, VectorDouble targets) {
        final boolean sizesMatch = targets.size() == dataSet.m();
        if (!sizesMatch) {
            throw new IllegalArgumentException(
                    "Dataset number of rows: " + dataSet.m() + " not equal to " + targets.size());
        }

        final VectorDouble residuals = new VectorDouble(targets.size(), 0.0d);
        for (int row = 0; row < dataSet.m(); row++) {
            final double observed = targets.get(row).doubleValue();
            final double fitted = ((Double) hypothesisType.evaluate(dataSet.getRow(row))).doubleValue();
            residuals.set(row, observed - fitted);
        }
        return residuals;
    }
}
