package maths.errorfunctions;

import base.CommonConstants;
import datasets.VectorDouble;
import datastructs.I2DDataSet;
import datastructs.IVector;
import maths.functions.IVectorRealFunction;

/**
 * Error function for a logistic-regression style hypothesis.
 *
 * <p>{@link #evaluate} returns the negated sum of per-sample log-likelihood terms
 * {@code y_i log h_i + (1 - y_i) log(1 - h_i)} (binary cross-entropy), where {@code h_i} is the
 * hypothesis output for row {@code i} and {@code y_i} its label. When a prediction is within
 * {@link CommonConstants#getTol()} of 1 or of 0 the corresponding logarithm diverges; such a
 * sample instead adds a fixed penalty of 1 to the error when its label disagrees with the
 * saturated prediction, and 0 otherwise.
 *
 * <p>{@link #gradients} returns, per coefficient {@code j}, the sum over samples of
 * {@code -2 * (y_i - h_i) * g_ij}, where {@code g_ij} is entry {@code j} of
 * {@code hypothesis.coeffGradients(row_i)}; i.e. a sum-of-squared-errors gradient, as the class
 * name suggests. NOTE(review): this is not the gradient of the log-loss returned by
 * {@link #evaluate} — confirm which objective callers intend to minimise.
 *
 * <p>NOTE(review): assumes hypothesis outputs lie in [0, 1]; outside that range the logarithms
 * in {@link #evaluate} produce NaN.
 */
public class LogisticSSEVectorFunction implements IVectorErrorRealFunction {

    /** The model whose outputs are compared against the labels; never null. */
    private final IVectorRealFunction<IVector<Double>> hypothesis;

    /**
     * Creates an error function for the given hypothesis.
     *
     * @param iVectorRealFunction the hypothesis function; must not be null
     * @throws IllegalArgumentException if {@code iVectorRealFunction} is null
     */
    public LogisticSSEVectorFunction(IVectorRealFunction<IVector<Double>> iVectorRealFunction) {
        if (iVectorRealFunction == null) {
            throw new IllegalArgumentException("Hypothesis function cannot be null");
        }
        this.hypothesis = iVectorRealFunction;
    }

    /**
     * Evaluates the log-loss of the hypothesis over a data set.
     *
     * @param datasettype the data set; row {@code i} is fed to the hypothesis for sample {@code i}
     * @param vectorDouble the labels, one per row of {@code datasettype}
     * @return {@code -sum_i [y_i log h_i + (1 - y_i) log(1 - h_i)]}, with saturated predictions
     *     handled as described in the class documentation
     * @throws IllegalArgumentException if the row count differs from the number of labels
     */
    @Override
    public <DataSetType extends I2DDataSet> double evaluate(DataSetType datasettype, VectorDouble vectorDouble) {
        checkSizes(datasettype, vectorDouble);
        final double tol = CommonConstants.getTol();
        double logLikelihood = 0.0d;
        for (int i = 0; i < datasettype.m(); i++) {
            final VectorDouble row = (VectorDouble) datasettype.getRow(i);
            final double label = vectorDouble.get(i).doubleValue();
            final double prediction = this.hypothesis.evaluate(row).doubleValue();
            if (Math.abs(prediction - 1.0d) < tol) {
                // h ~ 1: log(1 - h) diverges. A wrong label lowers the likelihood,
                // i.e. raises the returned error by 1.
                if (label != 1.0d) {
                    logLikelihood -= 1.0d;
                }
            } else if (Math.abs(prediction) >= tol) {
                logLikelihood += label * Math.log(prediction)
                        + (1.0d - label) * Math.log(1.0d - prediction);
            } else if (label > tol) {
                // h ~ 0: log(h) diverges. A non-zero label raises the returned error by 1.
                logLikelihood -= 1.0d;
            }
        }
        return -logLikelihood;
    }

    /**
     * Computes the per-coefficient gradient {@code sum_i -2 * (y_i - h_i) * dh_i/dtheta_j}.
     *
     * @param datasettype the data set; row {@code i} is fed to the hypothesis for sample {@code i}
     * @param vectorDouble the labels, one per row of {@code datasettype}
     * @return a new vector of length {@code hypothesis.numCoeffs()} holding the summed gradients
     * @throws IllegalArgumentException if the row count differs from the number of labels
     */
    @Override
    public <DataSetType extends I2DDataSet> VectorDouble gradients(DataSetType datasettype, VectorDouble vectorDouble) {
        checkSizes(datasettype, vectorDouble);
        final int numCoeffs = this.hypothesis.numCoeffs();
        final VectorDouble grads = new VectorDouble(numCoeffs, 0.0d);
        for (int i = 0; i < datasettype.m(); i++) {
            final VectorDouble row = (VectorDouble) datasettype.getRow(i);
            final double residual = vectorDouble.get(i).doubleValue() - this.hypothesis.evaluate(row).doubleValue();
            final IVector<Double> coeffGradients = this.hypothesis.coeffGradients(row);
            for (int j = 0; j < numCoeffs; j++) {
                // Accumulates this sample's contribution into entry j of the result.
                grads.add(j, Double.valueOf(-2.0d * residual * coeffGradients.get(j).doubleValue()));
            }
        }
        return grads;
    }

    /**
     * Verifies that the data set has exactly one row per label.
     *
     * @throws IllegalArgumentException if the row count differs from the number of labels
     */
    private static void checkSizes(I2DDataSet dataSet, VectorDouble labels) {
        if (dataSet.m() != labels.size()) {
            throw new IllegalArgumentException("Invalid number of data points and labels vector size");
        }
    }
}
