package smile.manifold;

import java.io.Serializable;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.invoke.SerializedLambda;
import java.lang.runtime.ObjectMethods;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;
import smile.util.AlgoStatus;
import smile.util.IterativeAlgorithmController;

/* loaded from: input_file:smile/manifold/TSNE.class */
public final class TSNE extends Record implements Serializable {
    private final double cost;
    private final double[][] coordinates;
    private static final long serialVersionUID = 3;
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /* loaded from: input_file:smile/manifold/TSNE$Options.class */
    /**
     * Hyperparameters of the t-SNE fit.
     *
     * <p>Declared as a real {@code record}: the accessors ({@code d()},
     * {@code perplexity()}, ...), {@code equals}, {@code hashCode} and
     * {@code toString} are generated by the compiler from the components, so
     * they need not (and, via {@code ObjectMethods.bootstrap}, cannot) be
     * spelled out by hand.
     *
     * @param d dimension of the embedding space; must be at least 2.
     * @param perplexity target perplexity of the conditional distributions;
     *                   must be at least 2.
     * @param eta learning rate; must be positive.
     * @param earlyExaggeration factor applied to P until
     *                          {@code momentumSwitchIter}; must be positive.
     * @param maxIter maximum number of iterations; must be at least 250.
     * @param maxIterWithoutProgress number of iterations without a new best
     *                               cost after which the fit stops; must be in
     *                               {@code [50, maxIter]}.
     * @param tol tolerance on the update magnitude used as a stopping
     *            criterion; must be positive.
     * @param momentum momentum used before {@code momentumSwitchIter};
     *                 must be positive.
     * @param finalMomentum momentum used from {@code momentumSwitchIter} on;
     *                      must be positive.
     * @param momentumSwitchIter iteration at which the momentum is switched
     *                           and the early exaggeration is removed; must be
     *                           in {@code (0, maxIter)}.
     * @param minGain lower bound applied to the per-coordinate gains;
     *                must be positive.
     * @param controller optional controller that receives progress reports
     *                   and may interrupt the fit; may be {@code null}.
     */
    public record Options(int d,
                          double perplexity,
                          double eta,
                          double earlyExaggeration,
                          int maxIter,
                          int maxIterWithoutProgress,
                          double tol,
                          double momentum,
                          double finalMomentum,
                          int momentumSwitchIter,
                          double minGain,
                          IterativeAlgorithmController<AlgoStatus> controller) {

        /**
         * Validates the hyperparameters.
         *
         * @throws IllegalArgumentException if any value is out of range.
         */
        public Options {
            if (d < 2) {
                throw new IllegalArgumentException("Invalid dimension of feature space: " + d);
            }
            if (perplexity < 2.0d) {
                throw new IllegalArgumentException("Invalid perplexity: " + perplexity);
            }
            if (eta <= 0.0d) {
                throw new IllegalArgumentException("Invalid learning rate: " + eta);
            }
            if (earlyExaggeration <= 0.0d) {
                throw new IllegalArgumentException("Invalid early exaggeration: " + earlyExaggeration);
            }
            if (maxIter < 250) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            if (maxIterWithoutProgress < 50 || maxIterWithoutProgress > maxIter) {
                throw new IllegalArgumentException("Invalid maximum number of iterations without progress: " + maxIterWithoutProgress);
            }
            if (tol <= 0.0d) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            if (momentum <= 0.0d) {
                throw new IllegalArgumentException("Invalid momentum: " + momentum);
            }
            if (finalMomentum <= 0.0d) {
                throw new IllegalArgumentException("Invalid final momentum: " + finalMomentum);
            }
            if (momentumSwitchIter <= 0 || momentumSwitchIter >= maxIter) {
                // The message used to say "Invalid learning rate", which pointed
                // users at the wrong parameter.
                throw new IllegalArgumentException("Invalid momentum switch iteration: " + momentumSwitchIter);
            }
            if (minGain <= 0.0d) {
                throw new IllegalArgumentException("Invalid minimum gain: " + minGain);
            }
        }

        /**
         * Creates options with the remaining hyperparameters at their defaults:
         * 50 iterations without progress, tolerance 1E-7, momentum 0.5, final
         * momentum 0.8, momentum switch at iteration 250, minimum gain 0.01 and
         * no controller.
         *
         * @param d dimension of the embedding space.
         * @param perplexity target perplexity.
         * @param eta learning rate.
         * @param earlyExaggeration early exaggeration factor.
         * @param maxIter maximum number of iterations.
         */
        public Options(int d, double perplexity, double eta, double earlyExaggeration, int maxIter) {
            this(d, perplexity, eta, earlyExaggeration, maxIter, 50, 1.0E-7d, 0.5d, 0.8d, 250, 0.01d, null);
        }

        /**
         * Returns the persistent set of hyperparameters.
         * The controller is not written since it has no textual form.
         *
         * @return the hyperparameters under the {@code smile.t_sne.*} keys.
         */
        public Properties toProperties() {
            Properties properties = new Properties();
            properties.setProperty("smile.t_sne.d", Integer.toString(d));
            properties.setProperty("smile.t_sne.perplexity", Double.toString(perplexity));
            properties.setProperty("smile.t_sne.eta", Double.toString(eta));
            properties.setProperty("smile.t_sne.early_exaggeration", Double.toString(earlyExaggeration));
            properties.setProperty("smile.t_sne.iterations", Integer.toString(maxIter));
            properties.setProperty("smile.t_sne.max_iterations_without_progress", Integer.toString(maxIterWithoutProgress));
            properties.setProperty("smile.t_sne.tolerance", Double.toString(tol));
            properties.setProperty("smile.t_sne.momentum", Double.toString(momentum));
            properties.setProperty("smile.t_sne.final_momentum", Double.toString(finalMomentum));
            properties.setProperty("smile.t_sne.momentum_switch", Integer.toString(momentumSwitchIter));
            properties.setProperty("smile.t_sne.min_gain", Double.toString(minGain));
            return properties;
        }

        /**
         * Reads the hyperparameters from the {@code smile.t_sne.*} keys written
         * by {@link #toProperties()}. Every key is optional; a missing key falls
         * back to the same default the five-argument constructor uses (and 12
         * for the early exaggeration), so an empty {@code Properties} is valid.
         *
         * @param properties the hyperparameters.
         * @return the options, with a {@code null} controller.
         * @throws NumberFormatException if a value cannot be parsed.
         * @throws IllegalArgumentException if a value is out of range.
         */
        public static Options of(Properties properties) {
            int d = Integer.parseInt(properties.getProperty("smile.t_sne.d", "2"));
            double perplexity = Double.parseDouble(properties.getProperty("smile.t_sne.perplexity", "20"));
            double eta = Double.parseDouble(properties.getProperty("smile.t_sne.eta", "200"));
            double earlyExaggeration = Double.parseDouble(properties.getProperty("smile.t_sne.early_exaggeration", "12"));
            int maxIter = Integer.parseInt(properties.getProperty("smile.t_sne.iterations", "1000"));
            int maxIterWithoutProgress = Integer.parseInt(properties.getProperty("smile.t_sne.max_iterations_without_progress", "50"));
            double tol = Double.parseDouble(properties.getProperty("smile.t_sne.tolerance", "1E-7"));
            double momentum = Double.parseDouble(properties.getProperty("smile.t_sne.momentum", "0.5"));
            double finalMomentum = Double.parseDouble(properties.getProperty("smile.t_sne.final_momentum", "0.8"));
            int momentumSwitchIter = Integer.parseInt(properties.getProperty("smile.t_sne.momentum_switch", "250"));
            // Previously parsed from "smile.t_sne.momentum_switch", so the stored
            // min_gain was ignored and the switch iteration was used as the gain.
            double minGain = Double.parseDouble(properties.getProperty("smile.t_sne.min_gain", "0.01"));
            return new Options(d, perplexity, eta, earlyExaggeration, maxIter, maxIterWithoutProgress,
                    tol, momentum, finalMomentum, momentumSwitchIter, minGain, null);
        }
    }

    /**
     * Creates a result holder.
     *
     * @param d the cost value to report through {@code cost()}.
     * @param dArr the embedding coordinates; the array is stored by reference,
     *             not copied.
     */
    public TSNE(double d, double[][] dArr) {
        this.coordinates = dArr;
        this.cost = d;
    }

    /**
     * Fits t-SNE with the default hyperparameters: 2 dimensions, perplexity 20,
     * learning rate 200, early exaggeration 12 and 1000 iterations.
     *
     * @param dArr the input data, forwarded unchanged to
     *             {@code fit(double[][], Options)}.
     * @return the fitted embedding.
     */
    public static TSNE fit(double[][] dArr) {
        Options defaults = new Options(2, 20.0d, 200.0d, 12.0d, 1000);
        return fit(dArr, defaults);
    }

    /**
     * Fits t-SNE on the given data by gradient descent with momentum and
     * per-coordinate gains.
     *
     * <p>Every 10th iteration (and on the last one) the cost is evaluated and
     * logged, the optional controller is notified, and the stopping criteria
     * are checked. After the momentum switch the fit stops early when no new
     * best cost has been seen for more than {@code maxIterWithoutProgress}
     * iterations or when the largest update falls below {@code tol}.
     *
     * @param dArr the input. A square array is used directly as the pairwise
     *             distance matrix; otherwise rows are treated as samples and
     *             the pairwise squared distances are computed.
     * @param options the hyperparameters.
     * @return the last evaluated cost and the embedding, centered so that each
     *         column has zero mean.
     * @throws IllegalArgumentException if {@code dArr} has no rows.
     */
    public static TSNE fit(double[][] dArr, Options options) {
        if (dArr.length == 0) {
            // Without this guard dArr[0] below fails with an index error.
            throw new IllegalArgumentException("Empty input data");
        }

        final double eta = options.eta;
        final int n = dArr.length;
        final int dim = options.d;

        final double[][] distances;
        if (n == dArr[0].length) {
            distances = dArr;
        } else {
            distances = new double[n][n];
            MathEx.pdist(dArr, distances, MathEx::squaredDistance);
        }

        // Random initial embedding and unit gains.
        final double[][] embedding = new double[n][dim];
        final double[][] gains = new double[n][dim];
        GaussianDistribution init = new GaussianDistribution(0.0d, 1.0E-4d);
        for (int row = 0; row < n; row++) {
            Arrays.fill(gains[row], 1.0d);
            double[] point = embedding[row];
            for (int k = 0; k < dim; k++) {
                point[k] = init.rand();
            }
        }

        final double[][] p = expd(distances, options.perplexity, 0.001d);
        final double[][] q = new double[n][n];
        final double[][] velocity = new double[n][dim];
        final double[][] gradient = new double[n][dim];

        // Symmetrize P and apply early exaggeration. Each row of P sums to 1,
        // hence the total of P + P' is 2n. The factor must be the configured
        // one because it is divided out again with options.earlyExaggeration
        // at the momentum switch; a hard-coded 12 left P mis-scaled for any
        // other setting.
        final double pTotal = 2 * n;
        for (int row = 0; row < n; row++) {
            double[] pRow = p[row];
            for (int col = 0; col < row; col++) {
                double value = (options.earlyExaggeration * (pRow[col] + p[col][row])) / pTotal;
                if (Double.isNaN(value) || value < 1.0E-16d) {
                    value = 1.0E-16d;
                }
                pRow[col] = value;
                p[col][row] = value;
            }
        }

        double cost = Double.MAX_VALUE;
        double bestCost = Double.MAX_VALUE;
        int bestIter = 0;
        double momentum = options.momentum;

        // A for-loop keeps the increment on every path; the while/continue
        // form skipped it on reporting iterations when no controller was set.
        for (int iter = 1; iter <= options.maxIter; iter++) {
            final double qSum = computeQ(embedding, q);

            // Gradient and gain update, one row per task.
            IntStream.range(0, n).parallel().forEach(row ->
                    sne(row, embedding, p, q, gains, velocity[row], gradient[row], qSum, options.minGain));

            // Momentum step; returns the largest gain-scaled update.
            final double mu = momentum;
            double maxUpdate = IntStream.range(0, n).parallel().mapToDouble(row -> {
                double[] point = embedding[row];
                double[] step = velocity[row];
                double[] grad = gradient[row];
                double[] gain = gains[row];
                double largest = 0.0d;
                for (int k = 0; k < dim; k++) {
                    step[k] = (mu * step[k]) - ((eta * gain[k]) * grad[k]);
                    point[k] = point[k] + step[k];
                    largest = Math.max(largest, Math.abs(step[k] * gain[k]));
                }
                return largest;
            }).max().orElse(0.0d);

            if (iter == options.momentumSwitchIter) {
                // End of the early phase: switch momentum, remove exaggeration.
                momentum = options.finalMomentum;
                for (int row = 0; row < n; row++) {
                    double[] pRow = p[row];
                    for (int col = 0; col < n; col++) {
                        pRow[col] = pRow[col] / options.earlyExaggeration;
                    }
                }
            }

            if (iter % 10 == 0 || iter == options.maxIter) {
                cost = computeCost(p, q, qSum);
                logger.info("Iteration {}: error = {}", iter, cost);

                if (cost < bestCost) {
                    bestCost = cost;
                    bestIter = iter;
                }

                if (iter > options.momentumSwitchIter) {
                    if (iter - bestIter > options.maxIterWithoutProgress) {
                        logger.info("Iteration {}: did not make any progress in last {} episodes. Finished", iter, options.maxIterWithoutProgress);
                        break;
                    }
                    if (maxUpdate < options.tol) {
                        logger.info("Iteration {}: gradient norm = {}. Finished", iter, maxUpdate);
                        break;
                    }
                }

                if (options.controller != null) {
                    options.controller.submit(new AlgoStatus(iter, cost));
                    if (options.controller.isInterrupted()) {
                        break;
                    }
                }
            }
        }

        // Center the embedding on the origin.
        final double[] colMeans = MathEx.colMeans(embedding);
        IntStream.range(0, n).parallel().forEach(row -> {
            double[] point = embedding[row];
            for (int k = 0; k < dim; k++) {
                point[k] = point[k] - colMeans[k];
            }
        });

        return new TSNE(cost, embedding);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Computes the cost gradient for one embedded point and adapts that
     * point's gains.
     *
     * <p>The gradient is accumulated over every other point {@code j} as
     * {@code 4 * (y_i - y_j) * (P_ij - Q_ij / d) * Q_ij}. Each gain is then
     * increased by 0.2 when the sign of the new gradient differs from the sign
     * of the previous update, multiplied by 0.8 otherwise, and finally floored
     * at {@code d2}.
     *
     * @param i index of the point to process.
     * @param dArr the embedding coordinates, one row per point.
     * @param dArr2 the P matrix; only row {@code i} is read.
     * @param dArr3 the Q matrix; only row {@code i} is read.
     * @param dArr4 the gains; row {@code i} is updated in place.
     * @param dArr5 the previous update of point {@code i}; read only.
     * @param dArr6 output buffer that receives the gradient of point {@code i}.
     * @param d the normalizer that Q entries are divided by.
     * @param d2 the minimum gain.
     */
    public static void sne(int i, double[][] dArr, double[][] dArr2, double[][] dArr3, double[][] dArr4, double[] dArr5, double[] dArr6, double d, double d2) {
        final int count = dArr.length;
        final int dim = dArr[0].length;
        final double[] self = dArr[i];
        final double[] pRow = dArr2[i];
        final double[] qRow = dArr3[i];
        final double[] gain = dArr4[i];

        Arrays.fill(dArr6, 0.0d);
        for (int j = 0; j < count; j++) {
            if (j == i) {
                continue;
            }
            final double[] other = dArr[j];
            final double q = qRow[j];
            final double weight = (pRow[j] - (q / d)) * q;
            for (int k = 0; k < dim; k++) {
                dArr6[k] += 4.0d * (self[k] - other[k]) * weight;
            }
        }

        for (int k = 0; k < dim; k++) {
            double updated;
            if (Math.signum(dArr6[k]) != Math.signum(dArr5[k])) {
                updated = gain[k] + 0.2d;
            } else {
                updated = gain[k] * 0.8d;
            }
            gain[k] = updated;
            if (gain[k] < d2) {
                gain[k] = d2;
            }
        }
    }

    /**
     * Computes the row-wise conditional probabilities P from pairwise
     * distances. For every row a precision {@code beta} is found by bisection
     * so that the entropy of {@code exp(-beta * D[i][j])} matches
     * {@code log2(d)}.
     *
     * <p>The search for a row stops once the entropy is within {@code d2} of
     * the target or after 50 steps. In both cases the returned row is
     * normalized to sum to 1 and has a zero diagonal entry.
     *
     * @param dArr the square pairwise distance matrix; the diagonal is assumed
     *             to be zero (the diagonal term {@code exp(0) = 1} is
     *             subtracted from the row total).
     * @param d the target perplexity.
     * @param d2 the tolerance on the entropy difference.
     * @return the matrix of conditional probabilities.
     */
    private static double[][] expd(double[][] dArr, double d, double d2) {
        final int n = dArr.length;
        final double[][] p = new double[n][n];
        final double[] rowSums = MathEx.rowSums(dArr);
        // The target entropy is the same for every row.
        final double logU = MathEx.log2(d);

        IntStream.range(0, n).parallel().forEach(row -> {
            final double[] pRow = p[row];
            final double[] dRow = dArr[row];

            // Start from sqrt(1 / average distance).
            double beta = Math.sqrt((n - 1) / rowSums[row]);
            double betaMin = 0.0d;
            double betaMax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", row, beta);

            double hDiff = Double.MAX_VALUE;
            double pSum = 0.0d;
            boolean normalized = false;
            // The step counter has its own name; reusing the row index here
            // made "pRow[row] = 0" and the log output ambiguous.
            for (int step = 0; Math.abs(hDiff) > d2 && step < 50; step++) {
                double total = 0.0d;
                double weighted = 0.0d;
                for (int col = 0; col < n; col++) {
                    double scaled = beta * dRow[col];
                    double value = Math.exp(-scaled);
                    pRow[col] = value;
                    total += value;
                    weighted += value * scaled;
                }

                // P[row][row] must be 0; its contribution exp(0) = 1 is
                // removed from the row total.
                pRow[row] = 0.0d;
                pSum = total - 1.0d;

                double entropy = MathEx.log2(pSum) + (weighted / pSum);
                hDiff = entropy - logU;

                if (Math.abs(hDiff) <= d2) {
                    for (int col = 0; col < n; col++) {
                        pRow[col] = pRow[col] / pSum;
                    }
                    normalized = true;
                } else if (hDiff > 0.0d) {
                    betaMin = beta;
                    beta = Double.isInfinite(betaMax) ? beta * 2.0d : (beta + betaMax) / 2.0d;
                } else {
                    betaMax = beta;
                    beta = (beta + betaMin) / 2.0d;
                }

                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", hDiff, row, beta, entropy, logU);
            }

            // If the search ran out of steps the row still holds raw
            // exponentials. Callers rely on every row summing to 1, so
            // normalize the last evaluated row as well.
            if (!normalized && pSum > 0.0d) {
                for (int col = 0; col < n; col++) {
                    pRow[col] = pRow[col] / pSum;
                }
            }
        });

        return p;
    }

    /**
     * Fills the Q matrix with {@code 1 / (1 + squaredDistance(y_i, y_j))} for
     * every pair of embedded points, diagonal included, and returns the sum of
     * all entries.
     *
     * <p>Rows are processed in parallel; the per-row totals are collected into
     * an array and summed afterwards by {@code MathEx.sum}.
     *
     * @param dArr the embedding coordinates, one row per point.
     * @param dArr2 the output matrix Q, overwritten in place.
     * @return the sum of all entries of Q.
     */
    private static double computeQ(double[][] dArr, double[][] dArr2) {
        final int n = dArr.length;
        // Row and column indices need distinct names: a loop variable may not
        // redeclare the lambda parameter.
        double[] rowTotals = IntStream.range(0, n).parallel().mapToDouble(row -> {
            final double[] point = dArr[row];
            final double[] qRow = dArr2[row];
            double total = 0.0d;
            for (int col = 0; col < n; col++) {
                double q = 1.0d / (1.0d + MathEx.squaredDistance(point, dArr[col]));
                qRow[col] = q;
                total += q;
            }
            return total;
        }).toArray();
        return MathEx.sum(rowTotals);
    }

    /**
     * Computes the divergence-style cost {@code sum p * log2(p / q)} between P
     * and the normalized Q.
     *
     * <p>Only the entries below the diagonal are visited and the result is
     * doubled, which covers the whole matrix when both inputs are symmetric.
     * Normalized Q values that are NaN or below 1E-16 are replaced by 1E-16.
     *
     * @param dArr the P matrix.
     * @param dArr2 the unnormalized Q matrix.
     * @param d the normalizer that Q entries are divided by.
     * @return the cost.
     */
    private static double computeCost(double[][] dArr, double[][] dArr2, double d) {
        return 2.0d * IntStream.range(0, dArr2.length).parallel().mapToDouble(row -> {
            final double[] pRow = dArr[row];
            final double[] qRow = dArr2[row];
            double partial = 0.0d;
            // Bounded by the row index. With a single name for both indices
            // the condition read "i < i" and the loop never ran.
            for (int col = 0; col < row; col++) {
                double pValue = pRow[col];
                double qValue = qRow[col] / d;
                if (Double.isNaN(qValue) || qValue < 1.0E-16d) {
                    qValue = 1.0E-16d;
                }
                partial += pValue * MathEx.log2(pValue / qValue);
            }
            return partial;
        }).sum();
    }

    /**
     * Returns the textual form of this record.
     *
     * <p>This is the decompiler's rendering of the {@code invoke-custom} call
     * site for {@code toString}: it asks {@code ObjectMethods.bootstrap} for an
     * implementation over the components {@code cost} and {@code coordinates}
     * and invokes it on this instance.
     * NOTE(review): this form is not expected to compile as ordinary source
     * ({@code invoke} is declared to throw {@code Throwable}) — confirm before
     * building from this file.
     */
    @Override // java.lang.Record
    public final String toString() {
        return (String) ObjectMethods.bootstrap(MethodHandles.lookup(), "toString", MethodType.methodType(String.class, TSNE.class), TSNE.class, "cost;coordinates", "FIELD:Lsmile/manifold/TSNE;->cost:D", "FIELD:Lsmile/manifold/TSNE;->coordinates:[[D").dynamicInvoker().invoke(this) /* invoke-custom */;
    }

    /**
     * Returns the hash code of this record.
     *
     * <p>This is the decompiler's rendering of the {@code invoke-custom} call
     * site for {@code hashCode}: the implementation is obtained from
     * {@code ObjectMethods.bootstrap} over the components {@code cost} and
     * {@code coordinates}.
     * NOTE(review): {@code coordinates} is an array, so its contribution is
     * presumably identity-based rather than content-based — confirm against
     * the {@code ObjectMethods} documentation.
     */
    @Override // java.lang.Record
    public final int hashCode() {
        return (int) ObjectMethods.bootstrap(MethodHandles.lookup(), "hashCode", MethodType.methodType(Integer.TYPE, TSNE.class), TSNE.class, "cost;coordinates", "FIELD:Lsmile/manifold/TSNE;->cost:D", "FIELD:Lsmile/manifold/TSNE;->coordinates:[[D").dynamicInvoker().invoke(this) /* invoke-custom */;
    }

    /**
     * Compares this record with another object.
     *
     * <p>This is the decompiler's rendering of the {@code invoke-custom} call
     * site for {@code equals}: the implementation is obtained from
     * {@code ObjectMethods.bootstrap} over the components {@code cost} and
     * {@code coordinates} and invoked with this instance and {@code obj}.
     * NOTE(review): {@code coordinates} is an array, so two results with equal
     * contents but distinct arrays presumably compare as unequal — confirm
     * against the {@code ObjectMethods} documentation.
     *
     * @param obj the object to compare with.
     */
    @Override // java.lang.Record
    public final boolean equals(Object obj) {
        return (boolean) ObjectMethods.bootstrap(MethodHandles.lookup(), "equals", MethodType.methodType(Boolean.TYPE, TSNE.class, Object.class), TSNE.class, "cost;coordinates", "FIELD:Lsmile/manifold/TSNE;->cost:D", "FIELD:Lsmile/manifold/TSNE;->coordinates:[[D").dynamicInvoker().invoke(this, obj) /* invoke-custom */;
    }

    /**
     * Returns the cost value stored in this result.
     *
     * @return the cost.
     */
    public double cost() {
        return cost;
    }

    /**
     * Returns the embedding coordinates. The internal array is returned
     * directly, not a copy, so changes made by the caller are visible here.
     *
     * @return the coordinates.
     */
    public double[][] coordinates() {
        return coordinates;
    }

    /**
     * Synthetic hook that rebuilds a serialized lambda of this class.
     *
     * <p>The only supported case is the method reference
     * {@code MathEx::squaredDistance}; every field of the serialized form is
     * checked against it before the reference is recreated. Anything else is
     * rejected.
     *
     * <p>NOTE(review): the {@code boolean z = -1} selector and the
     * {@code case false} label are a decompiler rendering of an integer switch
     * index and are not valid Java source; likewise the bare method reference
     * has no target type for an {@code Object} return — confirm before
     * building from this file.
     *
     * @param serializedLambda the serialized form of the lambda.
     * @return the recreated method reference.
     * @throws IllegalArgumentException if the serialized form does not match.
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        String implMethodName = serializedLambda.getImplMethodName();
        // First switch: map the implementation method name to a case index
        // (hash code first, then an exact string comparison).
        boolean z = -1;
        switch (implMethodName.hashCode()) {
            case -1421810692:
                if (implMethodName.equals("squaredDistance")) {
                    z = false;
                    break;
                }
                break;
        }
        // Second switch: verify the full signature for the matched case.
        switch (z) {
            case false:
                // Kind 6 is presumably MethodHandleInfo.REF_invokeStatic — TODO confirm.
                if (serializedLambda.getImplMethodKind() == 6 && serializedLambda.getFunctionalInterfaceClass().equals("smile/math/distance/Distance") && serializedLambda.getFunctionalInterfaceMethodName().equals("d") && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;Ljava/lang/Object;)D") && serializedLambda.getImplClass().equals("smile/math/MathEx") && serializedLambda.getImplMethodSignature().equals("([D[D)D")) {
                    return MathEx::squaredDistance;
                }
                break;
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
