package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/base/mlp/Layer.class */
/**
 * A layer of a multilayer perceptron.
 *
 * <p>A layer holds an {@code n x p} weight matrix and a bias vector of length
 * {@code n}. {@link #propagate(double[])} copies the bias into the output
 * buffer, lets the weight matrix accumulate its product with the input into
 * that buffer, and then hands the buffer to the subclass-specific
 * {@link #transform(double[])}.
 *
 * <p>All working buffers (output, gradients, moments, previous updates and the
 * dropout mask) live in {@link ThreadLocal}s, so every thread gets its own
 * lazily allocated copy. The {@link #weight} and {@link #bias} fields are
 * ordinary shared fields; this class performs no synchronization when it
 * modifies them.
 *
 * <p>The thread-local buffers are {@code transient}. They are rebuilt after
 * deserialization by {@code readObject}.
 */
public abstract class Layer implements AutoCloseable, Serializable {
    private static final long serialVersionUID = 2;

    /** Regex fragment for an optionally signed integer of at most nine digits. */
    private static final String INTEGER_REGEX = "[-+]?\\d{1,9}";
    /** Regex fragment for an optionally signed decimal number with optional exponent. */
    private static final String DOUBLE_REGEX = "[-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?";
    /**
     * Grammar of one layer in a specification string:
     * {@code name(neurons[, dropout[, param]])}. Compiled once because
     * {@link Pattern} instances are immutable and reusable.
     */
    private static final Pattern LAYER_SPEC = Pattern.compile(String.format(
            "(\\w+)\\((%s)(,\\s*(%s))?(,\\s*(%s))?\\)", INTEGER_REGEX, DOUBLE_REGEX, DOUBLE_REGEX));

    /** The number of output units: rows of {@link #weight}, length of {@link #bias}. */
    protected final int n;
    /** The number of input units: columns of {@link #weight}. */
    protected final int p;
    /** The dropout rate, in {@code [0, 1)}. Zero disables dropout. */
    protected final double dropout;
    /** The {@code n x p} weight matrix. Left {@code null} by {@link #Layer(int, double)}. */
    protected Matrix weight;
    /** The bias vector. Left {@code null} by {@link #Layer(int, double)}. */
    protected double[] bias;
    /** Per-thread output buffer of length {@code n}. */
    protected transient ThreadLocal<double[]> output;
    /** Per-thread output gradient buffer of length {@code n}. */
    protected transient ThreadLocal<double[]> outputGradient;
    /** Per-thread weight gradient accumulated by {@link #computeGradient(double[])}. */
    protected transient ThreadLocal<Matrix> weightGradient;
    /** Per-thread bias gradient accumulated by {@link #computeGradient(double[])}. */
    protected transient ThreadLocal<double[]> biasGradient;
    /** Per-thread {@code n x p} buffer. Allocated here but not read by any method of this class. */
    protected transient ThreadLocal<Matrix> weightGradientMoment1;
    /** Per-thread moving average of the squared weight gradient, maintained by {@code update}. */
    protected transient ThreadLocal<Matrix> weightGradientMoment2;
    /** Per-thread buffer of length {@code n}. Allocated here but not read by any method of this class. */
    protected transient ThreadLocal<double[]> biasGradientMoment1;
    /** Per-thread moving average of the squared bias gradient, maintained by {@code update}. */
    protected transient ThreadLocal<double[]> biasGradientMoment2;
    /** Per-thread previous weight update, carried over between steps when momentum is enabled. */
    protected transient ThreadLocal<Matrix> weightUpdate;
    /** Per-thread previous bias update, carried over between steps when momentum is enabled. */
    protected transient ThreadLocal<double[]> biasUpdate;
    /** Per-thread dropout mask (1 = kept, 0 = dropped). Only created when {@code dropout > 0}. */
    protected transient ThreadLocal<byte[]> mask;

    /**
     * Creates a layer without weights or bias whose input and output sizes
     * are both {@code i}. Only the output buffer (and the dropout mask, when
     * dropout is enabled) is set up; the gradient buffers stay {@code null}
     * until the object is deserialized.
     *
     * <p>NOTE(review): the decompiler reported that this constructor was
     * package-private in the original class file; it is kept public here.
     *
     * @param i the number of units.
     * @param d the dropout rate, in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code d} is NaN or outside {@code [0, 1)}.
     */
    public Layer(int i, double d) {
        checkDropout(d);
        this.n = i;
        this.p = i;
        this.dropout = d;
        this.output = ThreadLocal.withInitial(() -> new double[i]);
        if (d > 0.0) {
            this.mask = ThreadLocal.withInitial(() -> new byte[i]);
        }
    }

    /**
     * Creates a layer with randomly initialized weights and no dropout.
     *
     * @param i  the number of output units.
     * @param i2 the number of input units.
     */
    public Layer(int i, int i2) {
        this(i, i2, 0.0);
    }

    /**
     * Creates a layer with randomly initialized weights and a zero bias.
     * The weights are drawn by {@code Matrix.rand} between
     * {@code -sqrt(6 / (i + i2))} and {@code +sqrt(6 / (i + i2))}.
     *
     * @param i  the number of output units.
     * @param i2 the number of input units.
     * @param d  the dropout rate, in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code d} is NaN or outside {@code [0, 1)}.
     */
    public Layer(int i, int i2, double d) {
        this(randomWeight(i, i2), new double[i], d);
    }

    /**
     * Creates a layer from the given weights and bias, without dropout.
     *
     * @param matrix the weight matrix.
     * @param dArr   the bias vector.
     */
    public Layer(Matrix matrix, double[] dArr) {
        this(matrix, dArr, 0.0);
    }

    /**
     * Creates a layer from the given weights and bias. The layer sizes are
     * taken from the matrix: output size from its row count, input size from
     * its column count. Neither argument is copied.
     *
     * @param matrix the weight matrix.
     * @param dArr   the bias vector; presumably of length {@code matrix.nrow()},
     *               this is not checked here.
     * @param d      the dropout rate, in {@code [0, 1)}.
     * @throws IllegalArgumentException if {@code d} is NaN or outside {@code [0, 1)}.
     */
    public Layer(Matrix matrix, double[] dArr, double d) {
        checkDropout(d);
        this.n = matrix.nrow();
        this.p = matrix.ncol();
        this.weight = matrix;
        this.bias = dArr;
        this.dropout = d;
        init();
    }

    /**
     * Validates a dropout rate.
     *
     * @param rate the candidate dropout rate.
     * @throws IllegalArgumentException if {@code rate} is NaN or outside {@code [0, 1)}.
     */
    private static void checkDropout(double rate) {
        // Written as a negated range test so that NaN is rejected too;
        // "rate < 0 || rate >= 1" is false for NaN and would let it through.
        if (!(rate >= 0.0 && rate < 1.0)) {
            throw new IllegalArgumentException("Invalid dropout rate: " + rate);
        }
    }

    /**
     * Returns a random {@code rows x cols} weight matrix with entries drawn by
     * {@code Matrix.rand} between {@code -sqrt(6 / (rows + cols))} and
     * {@code +sqrt(6 / (rows + cols))}.
     */
    private static Matrix randomWeight(int rows, int cols) {
        double bound = Math.sqrt(6.0 / (rows + cols));
        return Matrix.rand(rows, cols, -bound, bound);
    }

    /**
     * Removes the calling thread's copies of all thread-local buffers.
     * Buffers of other threads are not affected, and the buffers are
     * re-created on next access.
     */
    @Override
    public void close() {
        release(this.output);
        release(this.outputGradient);
        release(this.weightGradient);
        release(this.biasGradient);
        release(this.weightGradientMoment1);
        release(this.weightGradientMoment2);
        release(this.biasGradientMoment1);
        release(this.biasGradientMoment2);
        release(this.weightUpdate);
        release(this.biasUpdate);
        release(this.mask);
    }

    /** Removes the calling thread's value of a thread-local that may not have been created. */
    private static void release(ThreadLocal<?> local) {
        if (local != null) {
            local.remove();
        }
    }

    /**
     * Restores the serialized fields and then rebuilds the transient
     * thread-local buffers.
     *
     * @param objectInputStream the stream to read from.
     * @throws IOException            if reading the stream fails.
     * @throws ClassNotFoundException if a serialized class cannot be found.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        objectInputStream.defaultReadObject();
        init();
    }

    /**
     * Creates all thread-local buffers. Allocation is lazy: each supplier
     * runs the first time a thread reads the corresponding buffer.
     */
    private void init() {
        this.output = ThreadLocal.withInitial(() -> new double[this.n]);
        this.outputGradient = ThreadLocal.withInitial(() -> new double[this.n]);
        this.weightGradient = ThreadLocal.withInitial(() -> new Matrix(this.n, this.p));
        this.biasGradient = ThreadLocal.withInitial(() -> new double[this.n]);
        this.weightGradientMoment1 = ThreadLocal.withInitial(() -> new Matrix(this.n, this.p));
        this.weightGradientMoment2 = ThreadLocal.withInitial(() -> new Matrix(this.n, this.p));
        this.biasGradientMoment1 = ThreadLocal.withInitial(() -> new double[this.n]);
        this.biasGradientMoment2 = ThreadLocal.withInitial(() -> new double[this.n]);
        this.weightUpdate = ThreadLocal.withInitial(() -> new Matrix(this.n, this.p));
        this.biasUpdate = ThreadLocal.withInitial(() -> new double[this.n]);
        if (this.dropout > 0.0) {
            this.mask = ThreadLocal.withInitial(() -> new byte[this.n]);
        }
    }

    /**
     * Returns the number of output units.
     *
     * @return the output size {@code n}.
     */
    public int getOutputSize() {
        return this.n;
    }

    /**
     * Returns the number of input units.
     *
     * @return the input size {@code p}.
     */
    public int getInputSize() {
        return this.p;
    }

    /**
     * Returns the calling thread's output buffer.
     *
     * @return the output array itself, not a copy.
     */
    public double[] output() {
        return this.output.get();
    }

    /**
     * Returns the calling thread's output gradient buffer.
     *
     * @return the output gradient array itself, not a copy.
     */
    public double[] gradient() {
        return this.outputGradient.get();
    }

    /**
     * Computes this layer's output for the given input and stores it in the
     * calling thread's output buffer.
     *
     * @param dArr the input vector.
     */
    public void propagate(double[] dArr) {
        double[] out = this.output.get();
        // Seed the buffer with the bias, then let the matrix accumulate into it.
        System.arraycopy(this.bias, 0, out, 0, this.n);
        // Assumes Matrix.mv(alpha, x, beta, y) computes y = alpha*A*x + beta*y,
        // i.e. out = weight * input + bias -- confirm against the Matrix API.
        this.weight.mv(1.0, dArr, 1.0, out);
        transform(out);
    }

    /**
     * Applies dropout to the calling thread's output buffer. Does nothing
     * when the dropout rate is zero. Each unit is dropped (set to zero) when
     * {@code MathEx.random()} is below the dropout rate; surviving units are
     * scaled by {@code 1 / (1 - dropout)}. The keep/drop decisions are
     * recorded in the mask for {@link #backpopagateDropout()}.
     */
    public void propagateDropout() {
        if (this.dropout > 0.0) {
            double[] out = this.output.get();
            byte[] keep = this.mask.get();
            double scale = 1.0 / (1.0 - this.dropout);
            for (int k = 0; k < this.n; k++) {
                byte retained = (byte) (MathEx.random() < this.dropout ? 0 : 1);
                keep[k] = retained;
                out[k] = out[k] * retained * scale;
            }
        }
    }

    /**
     * Applies the layer-specific transformation in place. Called by
     * {@link #propagate(double[])} on the buffer holding the weighted input
     * plus bias.
     *
     * @param dArr the buffer to transform in place.
     */
    public abstract void transform(double[] dArr);

    /**
     * Layer-specific backward pass.
     *
     * @param dArr presumably the gradient buffer of the preceding (lower)
     *             layer, to be filled by the implementation -- confirm
     *             against the subclasses.
     */
    public abstract void backpropagate(double[] dArr);

    /**
     * Applies the dropout mask recorded by {@link #propagateDropout()} to the
     * calling thread's output gradient, with the same
     * {@code 1 / (1 - dropout)} scaling. Does nothing when the dropout rate
     * is zero.
     *
     * <p>The method name is misspelled (missing "r"); it is kept as is for
     * compatibility with existing callers.
     */
    public void backpopagateDropout() {
        if (this.dropout > 0.0) {
            double[] gradient = this.outputGradient.get();
            byte[] keep = this.mask.get();
            double scale = 1.0 / (1.0 - this.dropout);
            for (int k = 0; k < this.n; k++) {
                gradient[k] = gradient[k] * keep[k] * scale;
            }
        }
    }

    /**
     * Updates the weights and bias immediately from the calling thread's
     * current output gradient (single-sample update).
     *
     * @param dArr the input vector that produced the current output gradient.
     * @param d    the learning rate; scales the term added to weights and bias.
     * @param d2   the momentum factor. Momentum is applied only when it lies
     *             strictly between 0 and 1; any other value (including NaN)
     *             disables it.
     * @param d3   the weight decay factor. The weights are multiplied by it
     *             only when it lies strictly between 0.9 and 1; any other
     *             value (including NaN) disables decay.
     */
    public void computeGradientUpdate(double[] dArr, double d, double d2, double d3) {
        final double[] x = dArr;
        final double eta = d;
        final double momentum = d2;
        final double decay = d3;

        double[] gradient = this.outputGradient.get();
        if (momentum > 0.0 && momentum < 1.0) {
            // update = momentum * previousUpdate + eta * (gradient, x) term
            Matrix deltaWeight = this.weightUpdate.get();
            double[] deltaBias = this.biasUpdate.get();
            deltaWeight.mul(momentum);
            // Assumes Matrix.add(alpha, x, y) is the rank-1 update
            // A += alpha * x * y' -- confirm against the Matrix API.
            deltaWeight.add(eta, gradient, x);
            this.weight.add(deltaWeight);
            for (int k = 0; k < this.n; k++) {
                double delta = (momentum * deltaBias[k]) + (eta * gradient[k]);
                deltaBias[k] = delta;
                this.bias[k] += delta;
            }
        } else {
            this.weight.add(eta, gradient, x);
            for (int k = 0; k < this.n; k++) {
                this.bias[k] += eta * gradient[k];
            }
        }

        if (decay > 0.9 && decay < 1.0) {
            this.weight.mul(decay);
        }
    }

    /**
     * Accumulates the calling thread's current output gradient into the
     * per-thread weight and bias gradient buffers, for a later
     * {@link #update(int, double, double, double, double, double)}.
     *
     * @param dArr the input vector that produced the current output gradient.
     */
    public void computeGradient(double[] dArr) {
        double[] gradient = this.outputGradient.get();
        Matrix gradWeight = this.weightGradient.get();
        double[] gradBias = this.biasGradient.get();
        gradWeight.add(1.0, gradient, dArr);
        for (int k = 0; k < this.n; k++) {
            gradBias[k] += gradient[k];
        }
    }

    /**
     * Applies the gradients accumulated by {@link #computeGradient(double[])}
     * on the calling thread to the weights and bias, then resets the
     * accumulators to zero.
     *
     * @param i  the mini-batch size; the accumulated gradient is averaged over it.
     * @param d  the learning rate.
     * @param d2 the momentum factor. Momentum is applied only when it lies
     *           strictly between 0 and 1; any other value (including NaN)
     *           disables it.
     * @param d3 the weight decay factor. The weights are multiplied by it
     *           only when it lies strictly between 0.9 and 1.
     * @param d4 the discounting factor of the moving average of squared
     *           gradients (RMSProp). The rescaling is applied only when it
     *           lies strictly between 0 and 1.
     * @param d5 a small constant added under the square root of the RMSProp
     *           denominator.
     */
    public void update(int i, double d, double d2, double d3, double d4, double d5) {
        final int batchSize = i;
        final double momentum = d2;
        final double decay = d3;
        final double rho = d4;
        final double epsilon = d5;

        Matrix gradWeight = this.weightGradient.get();
        double[] gradBias = this.biasGradient.get();

        // Without RMSProp the averaging over the batch is folded into the rate.
        double eta = d / batchSize;
        if (rho > 0.0 && rho < 1.0) {
            // RMSProp averages the gradient itself, so the rate is used as is.
            eta = d;
            rmsPropRescale(gradWeight, gradBias, batchSize, rho, epsilon);
        }

        if (momentum > 0.0 && momentum < 1.0) {
            Matrix deltaWeight = this.weightUpdate.get();
            double[] deltaBias = this.biasUpdate.get();
            // Assumes Matrix.add(alpha, beta, B) computes A = alpha*A + beta*B
            // -- confirm against the Matrix API.
            deltaWeight.add(momentum, eta, gradWeight);
            for (int k = 0; k < this.n; k++) {
                deltaBias[k] = (momentum * deltaBias[k]) + (eta * gradBias[k]);
            }
            this.weight.add(deltaWeight);
            MathEx.add(this.bias, deltaBias);
        } else {
            this.weight.add(eta, gradWeight);
            for (int k = 0; k < this.n; k++) {
                this.bias[k] += eta * gradBias[k];
            }
        }

        if (decay > 0.9 && decay < 1.0) {
            this.weight.mul(decay);
        }

        // Start the next mini-batch from empty accumulators.
        gradWeight.fill(0.0);
        Arrays.fill(gradBias, 0.0);
    }

    /**
     * Averages the accumulated gradients over the batch, folds their squares
     * into the second-moment moving averages, and divides each gradient
     * entry by {@code sqrt(epsilon + secondMoment)}. All changes are made in
     * place.
     */
    private void rmsPropRescale(Matrix gradWeight, double[] gradBias, int batchSize, double rho, double epsilon) {
        gradWeight.div(batchSize);
        for (int k = 0; k < this.n; k++) {
            gradBias[k] /= batchSize;
        }

        Matrix moment2Weight = this.weightGradientMoment2.get();
        double[] moment2Bias = this.biasGradientMoment2.get();
        double oneMinusRho = 1.0 - rho;

        for (int col = 0; col < this.p; col++) {
            for (int row = 0; row < this.n; row++) {
                moment2Weight.set(row, col,
                        (rho * moment2Weight.get(row, col)) + (oneMinusRho * MathEx.pow2(gradWeight.get(row, col))));
            }
        }
        for (int k = 0; k < this.n; k++) {
            moment2Bias[k] = (rho * moment2Bias[k]) + (oneMinusRho * MathEx.pow2(gradBias[k]));
        }

        for (int col = 0; col < this.p; col++) {
            for (int row = 0; row < this.n; row++) {
                gradWeight.div(row, col, Math.sqrt(epsilon + moment2Weight.get(row, col)));
            }
        }
        for (int k = 0; k < this.n; k++) {
            gradBias[k] /= Math.sqrt(epsilon + moment2Bias[k]);
        }
    }

    /**
     * Returns a hidden layer builder for the named activation function.
     * The name is matched case-insensitively against {@code relu},
     * {@code sigmoid}, {@code tanh}, {@code linear} and {@code leaky}.
     *
     * @param str the activation function name.
     * @param i   the number of neurons.
     * @param d   the dropout rate.
     * @param d2  the extra activation parameter; only used by {@code leaky},
     *            where NaN selects the default leaky activation.
     * @return the hidden layer builder.
     * @throws IllegalArgumentException if the activation name is not supported.
     */
    public static HiddenLayerBuilder builder(String str, int i, double d, double d2) {
        switch (str.toLowerCase(Locale.ROOT)) {
            case "relu":
                return rectifier(i, d);
            case "sigmoid":
                return sigmoid(i, d);
            case "tanh":
                return tanh(i, d);
            case "linear":
                return linear(i, d);
            case "leaky":
                return Double.isNaN(d2) ? leaky(i, d) : leaky(i, d, d2);
            default:
                throw new IllegalArgumentException("Unsupported activation function: " + str);
        }
    }

    /**
     * Returns an input layer builder without dropout.
     *
     * @param i the number of input units.
     * @return the layer builder.
     */
    public static LayerBuilder input(int i) {
        return input(i, 0.0);
    }

    /**
     * Returns an input layer builder.
     *
     * @param i the number of input units.
     * @param d the dropout rate.
     * @return the layer builder; the argument of its {@code build} method is ignored.
     */
    public static LayerBuilder input(int i, double d) {
        return new LayerBuilder(i, d) {
            @Override
            public InputLayer build(int i2) {
                return new InputLayer(this.neurons, this.dropout);
            }
        };
    }

    /**
     * Returns a hidden layer builder with linear activation and no dropout.
     *
     * @param i the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int i) {
        return linear(i, 0.0);
    }

    /**
     * Returns a hidden layer builder with linear activation.
     *
     * @param i the number of neurons.
     * @param d the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int i, double d) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.linear());
    }

    /**
     * Returns a hidden layer builder with rectifier activation and no dropout.
     *
     * @param i the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int i) {
        return rectifier(i, 0.0);
    }

    /**
     * Returns a hidden layer builder with rectifier activation.
     *
     * @param i the number of neurons.
     * @param d the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int i, double d) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.rectifier());
    }

    /**
     * Returns a hidden layer builder with the default leaky activation and
     * no dropout.
     *
     * @param i the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int i) {
        // Previously delegated to rectifier(i, 0.0), which silently produced
        // a plain rectifier layer instead of a leaky one.
        return leaky(i, 0.0);
    }

    /**
     * Returns a hidden layer builder with the default leaky activation.
     *
     * @param i the number of neurons.
     * @param d the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int i, double d) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.leaky());
    }

    /**
     * Returns a hidden layer builder with a parameterized leaky activation.
     *
     * @param i  the number of neurons.
     * @param d  the dropout rate.
     * @param d2 the parameter passed to {@code ActivationFunction.leaky(double)}.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int i, double d, double d2) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.leaky(d2));
    }

    /**
     * Returns a hidden layer builder with sigmoid activation and no dropout.
     *
     * @param i the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int i) {
        return sigmoid(i, 0.0);
    }

    /**
     * Returns a hidden layer builder with sigmoid activation.
     *
     * @param i the number of neurons.
     * @param d the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int i, double d) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.sigmoid());
    }

    /**
     * Returns a hidden layer builder with tanh activation and no dropout.
     *
     * @param i the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int i) {
        return tanh(i, 0.0);
    }

    /**
     * Returns a hidden layer builder with tanh activation.
     *
     * @param i the number of neurons.
     * @param d the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int i, double d) {
        return new HiddenLayerBuilder(i, d, ActivationFunction.tanh());
    }

    /**
     * Returns an output layer builder using the mean squared error cost.
     *
     * @param i              the number of neurons.
     * @param outputFunction the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int i, OutputFunction outputFunction) {
        return new OutputLayerBuilder(i, outputFunction, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns an output layer builder using the likelihood cost.
     *
     * @param i              the number of neurons.
     * @param outputFunction the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int i, OutputFunction outputFunction) {
        return new OutputLayerBuilder(i, outputFunction, Cost.LIKELIHOOD);
    }

    /**
     * Parses a network specification into layer builders.
     *
     * <p>The specification is a {@code |}-separated list of layers, each
     * written as {@code name(neurons[, dropout[, param]])}, for example
     * {@code "input(20, 0.2)|relu(50, 0.5)|leaky(30, 0.0, 0.1)"}. Whitespace
     * around each layer is ignored. If the first layer is not named
     * {@code input} (case-insensitive), an input layer of size {@code i2}
     * without dropout is inserted in front of it; otherwise {@code i2} is not
     * used.
     *
     * <p>An output layer is always appended, chosen from {@code i}:
     * <ul>
     *   <li>{@code i < 2}: one linear unit with mean squared error cost;</li>
     *   <li>{@code i == 2}: one sigmoid unit with likelihood cost;</li>
     *   <li>{@code i > 2}: {@code i} softmax units with likelihood cost.</li>
     * </ul>
     *
     * @param i   the number of output classes (see above).
     * @param i2  the number of input variables.
     * @param str the layer specification.
     * @return the layer builders, input layer first and output layer last.
     * @throws IllegalArgumentException if a layer does not match the grammar
     *         or names an unsupported activation function.
     */
    public static LayerBuilder[] of(int i, int i2, String str) {
        final int classes = i;
        final int inputSize = i2;

        String[] specs = str.split("\\|");
        ArrayList<LayerBuilder> builders = new ArrayList<>();
        for (int idx = 0; idx < specs.length; idx++) {
            String spec = specs[idx].trim();
            Matcher matcher = LAYER_SPEC.matcher(spec);
            if (!matcher.matches()) {
                throw new IllegalArgumentException("Invalid layer: " + spec);
            }

            String name = matcher.group(1);
            int neurons = Integer.parseInt(matcher.group(2));
            // Groups 3 and 5 are the optional ", value" parts; 4 and 6 the bare numbers.
            double rate = matcher.group(3) != null ? Double.parseDouble(matcher.group(4)) : 0.0;
            double param = matcher.group(5) != null ? Double.parseDouble(matcher.group(6)) : Double.NaN;

            if (idx != 0) {
                builders.add(builder(name, neurons, rate, param));
            } else if (name.equalsIgnoreCase("input")) {
                builders.add(input(neurons, rate));
            } else {
                builders.add(input(inputSize));
                builders.add(builder(name, neurons, rate, param));
            }
        }

        if (classes < 2) {
            builders.add(mse(1, OutputFunction.LINEAR));
        } else if (classes == 2) {
            builders.add(mle(1, OutputFunction.SIGMOID));
        } else {
            builders.add(mle(classes, OutputFunction.SOFTMAX));
        }
        return builders.toArray(new LayerBuilder[0]);
    }
}
