package com.jujutsu.tsne;

import com.jujutsu.tsne.TSne;
import com.jujutsu.utils.MatrixOps;
import java.util.Arrays;

/* loaded from: input_file:com/jujutsu/tsne/SimpleTSne.class */
/**
 * Dense-matrix t-SNE (t-distributed Stochastic Neighbor Embedding) implemented with {@link MatrixOps}.
 *
 * <p>Both the input affinities ({@link #x2p}) and every gradient-descent iteration build full n x n
 * matrices, so time and memory grow quadratically with the number of input rows.
 *
 * <p>Cancellation: {@link #abort()} sets a flag that {@link #tsne} polls once per iteration. Nothing in
 * this class clears the flag, so after an abort, later {@code tsne} calls on the same instance skip the
 * optimisation loop and return the initial {@code MatrixOps.rnorm} embedding.
 */
public class SimpleTSne implements TSne {

    /** Momentum used while {@code iter < MOMENTUM_SWITCH_ITER}. */
    private static final double INITIAL_MOMENTUM = 0.5d;
    /** Momentum used from iteration {@code MOMENTUM_SWITCH_ITER} onwards. */
    private static final double FINAL_MOMENTUM = 0.8d;
    private static final int MOMENTUM_SWITCH_ITER = 20;
    /** Step size (eta). Kept as an int so the same {@code MatrixOps.scalarMult} overload is selected. */
    private static final int LEARNING_RATE = 500;
    /** Lower bound applied to every adaptive gain. */
    private static final double MIN_GAIN = 0.01d;
    /** Factor applied to P during the first iterations ("early exaggeration"). */
    private static final double EARLY_EXAGGERATION = 4.0d;
    /** Iteration at which the early exaggeration is removed again. */
    private static final int STOP_EXAGGERATION_ITER = 100;
    /** Floor applied to P and Q so that log(P / Q) stays finite. */
    private static final double MIN_PROBABILITY = 1.0E-12d;
    /** Entropy tolerance passed to {@link #x2p} by {@link #tsne}. */
    private static final double PERPLEXITY_TOLERANCE = 1.0E-5d;
    /** Maximum number of binary-search steps per row in {@link #x2p}. */
    private static final int MAX_BINARY_SEARCH_STEPS = 50;

    /** Instance used for the non-static MatrixOps operations (transpose, minus, scalarMultiply). */
    MatrixOps mo = new MatrixOps();

    /** Set by {@link #abort()}; checked before each gradient-descent iteration in {@link #tsne}. */
    protected volatile boolean abort = false;

    /**
     * Embeds the rows of {@code tSneConfiguration.getXin()} into {@code getOutputDims()} dimensions.
     *
     * <p>If PCA is enabled, {@code getInitialDims()} is positive and the input has more columns than
     * that, the data is first reduced with {@link PrincipalComponentAnalysis}. The embedding starts
     * from {@code MatrixOps.rnorm} and is refined by gradient descent with momentum and per-coordinate
     * adaptive gains. This runs for at most {@code getMaxIter()} iterations, or until {@link #abort()}
     * is called. P is multiplied by {@link #EARLY_EXAGGERATION} until iteration
     * {@link #STOP_EXAGGERATION_ITER}. Progress is printed to standard output.
     *
     * @param tSneConfiguration input data and hyper-parameters
     * @return the embedding, one row per input row and {@code getOutputDims()} columns
     * @throws IllegalArgumentException if the input matrix is null or has no rows
     */
    @Override
    public double[][] tsne(TSneConfiguration tSneConfiguration) {
        double[][] x = tSneConfiguration.getXin();
        int outputDims = tSneConfiguration.getOutputDims();
        int initialDims = tSneConfiguration.getInitialDims();
        double perplexity = tSneConfiguration.getPerplexity();
        int maxIter = tSneConfiguration.getMaxIter();
        boolean usePca = tSneConfiguration.usePca();
        if (x == null || x.length == 0) {
            throw new IllegalArgumentException("t-SNE input matrix must have at least one row");
        }

        System.out.println("X:Shape is = " + x.length + " x " + x[0].length);
        System.out.println("Running " + getClass().getSimpleName() + ".");
        if (usePca && x[0].length > initialDims && initialDims > 0) {
            x = new PrincipalComponentAnalysis().pca(x, initialDims);
            System.out.println("X:Shape after PCA is = " + x.length + " x " + x[0].length);
        }

        int n = x.length;
        double[][] y = MatrixOps.rnorm(n, outputDims);
        double[][] update = MatrixOps.fillMatrix(n, outputDims, 0.0d);
        double[][] gains = MatrixOps.fillMatrix(n, outputDims, 1.0d);
        double[][] p = jointProbabilities(x, perplexity);
        System.out.println("Y:Shape is = " + y.length + " x " + y[0].length);

        for (int iter = 0; iter < maxIter && !abort; iter++) {
            // num[i][j] = 1 / (1 + |y_i - y_j|^2), with |y_i - y_j|^2 expanded as
            // |y_i|^2 + |y_j|^2 - 2 y_i.y_j. The diagonal is zeroed so no point is its own neighbour.
            double[][] sumY = mo.transpose(MatrixOps.sum(MatrixOps.square(y), 1));
            double[][] negTwoYYt = MatrixOps.scalarMult(MatrixOps.times(y, mo.transpose(y)), -2.0d);
            double[][] squaredDistances =
                    MatrixOps.addRowVector(mo.transpose(MatrixOps.addRowVector(negTwoYYt, sumY)), sumY);
            double[][] num = MatrixOps.scalarInverse(MatrixOps.scalarPlus(squaredDistances, 1.0d));
            MatrixOps.assignAtIndex(num, MatrixOps.range(n), MatrixOps.range(n), 0.0d);
            double[][] q = MatrixOps.maximum(MatrixOps.scalarDivide(num, MatrixOps.sum(num)), MIN_PROBABILITY);

            // Gradient: dY = 4 * (diag(rowSums(L)) - L) * Y, where L = (P - Q) .* num.
            double[][] l = mo.scalarMultiply(mo.minus(p, q), num);
            double[][] dY = MatrixOps.scalarMult(
                    MatrixOps.times(mo.minus(MatrixOps.diag(MatrixOps.sum(l, 1)), l), y), 4.0d);

            // Adaptive gains: +0.2 where sign(dY) differs from sign(update), *0.8 where it agrees
            // (assuming MatrixOps.negate is a logical NOT). Gains are then floored at MIN_GAIN.
            gains = MatrixOps.plus(
                    mo.scalarMultiply(MatrixOps.scalarPlus(gains, 0.2d),
                            MatrixOps.abs(MatrixOps.negate(MatrixOps.equal(
                                    MatrixOps.biggerThan(dY, 0.0d), MatrixOps.biggerThan(update, 0.0d))))),
                    mo.scalarMultiply(MatrixOps.scalarMult(gains, 0.8d),
                            MatrixOps.abs(MatrixOps.equal(
                                    MatrixOps.biggerThan(dY, 0.0d), MatrixOps.biggerThan(update, 0.0d)))));
            MatrixOps.assignAllLessThan(gains, MIN_GAIN, MIN_GAIN);

            // Momentum step, then subtract the mean row (MatrixOps.mean(.., 0)) from every row.
            double momentum = iter < MOMENTUM_SWITCH_ITER ? INITIAL_MOMENTUM : FINAL_MOMENTUM;
            update = mo.minus(MatrixOps.scalarMult(update, momentum),
                    MatrixOps.scalarMult(mo.scalarMultiply(gains, dY), LEARNING_RATE));
            double[][] moved = MatrixOps.plus(y, update);
            y = mo.minus(moved, MatrixOps.tile(MatrixOps.mean(moved, 0), n, 1));

            if (iter % 100 == 0) {
                System.out.println("Iteration " + (iter + 1) + ": error is " + klDivergence(p, q));
            } else if ((iter + 1) % 10 == 0) {
                System.out.println("Iteration " + (iter + 1));
            }

            if (iter == STOP_EXAGGERATION_ITER) {
                p = MatrixOps.scalarDivide(p, EARLY_EXAGGERATION);
            }
        }
        return y;
    }

    /**
     * Builds the joint affinity matrix used by the optimiser.
     *
     * <p>The conditional affinities from {@link #x2p} are symmetrised (P + P^T) and normalised to sum
     * to 1. The result is then multiplied by {@link #EARLY_EXAGGERATION} and floored at
     * {@link #MIN_PROBABILITY}.
     */
    private double[][] jointProbabilities(double[][] x, double perplexity) {
        double[][] conditional = x2p(x, PERPLEXITY_TOLERANCE, perplexity).P;
        double[][] symmetric = MatrixOps.plus(conditional, mo.transpose(conditional));
        double[][] normalised = MatrixOps.scalarDivide(symmetric, MatrixOps.sum(symmetric));
        return MatrixOps.maximum(MatrixOps.scalarMult(normalised, EARLY_EXAGGERATION), MIN_PROBABILITY);
    }

    /** Returns sum(P * log(P / Q)), with NaN entries of the log replaced by 0 before weighting by P. */
    private double klDivergence(double[][] p, double[][] q) {
        return MatrixOps.sum(
                mo.scalarMultiply(p, MatrixOps.replaceNaN(MatrixOps.log(MatrixOps.scalarDivide(p, q)), 0.0d)));
    }

    /**
     * Evaluates the kernel exp(-beta * D) over one point's squared distances.
     *
     * <p>If every weight underflows to 0, the sum is 0 and both outputs become NaN or infinite.
     *
     * @param distances distances from one point to the others ({@link #x2p} passes a single row)
     * @param beta      precision of the kernel
     * @return an R whose {@code P} is the kernel normalised to sum to 1 and whose {@code H} is the
     *         Shannon entropy of that distribution in nats
     */
    public TSne.R Hbeta(double[][] distances, double beta) {
        double[][] weights = MatrixOps.exp(MatrixOps.scalarMult(MatrixOps.scalarMult(distances, beta), -1.0d));
        double total = MatrixOps.sum(weights);
        // For p = w / Z with w = exp(-beta * D): -sum(p log p) = log(Z) + beta * sum(D * w) / Z.
        double entropy = Math.log(total) + ((beta * MatrixOps.sum(mo.scalarMultiply(distances, weights))) / total);
        TSne.R r = new TSne.R();
        r.H = entropy;
        r.P = MatrixOps.scalarDivide(weights, total);
        return r;
    }

    /**
     * Computes conditional affinities p(j|i) for every input row.
     *
     * <p>Each row's kernel precision beta is chosen by binary search so that the entropy of the row's
     * distribution matches log(perplexity). Squared distances are formed as
     * |x_i|^2 + |x_j|^2 - 2 x_i.x_j. For each row, the search stops once |H - log(perplexity)| <= tol,
     * or after {@link #MAX_BINARY_SEARCH_STEPS} steps.
     *
     * @param x          input data, one observation per row
     * @param tol        tolerance on the entropy difference
     * @param perplexity target perplexity
     * @return an R whose {@code P} is the n x n matrix of conditional affinities (diagonal left at 0)
     *         and whose {@code beta} holds the precision chosen for each row
     */
    public TSne.R x2p(double[][] x, double tol, double perplexity) {
        int n = x.length;
        double[][] sumX = MatrixOps.sum(MatrixOps.square(x), 1);
        double[][] distances = MatrixOps.addRowVector(
                MatrixOps.addColumnVector(
                        mo.transpose(MatrixOps.scalarMult(MatrixOps.times(x, mo.transpose(x)), -2.0d)), sumX),
                mo.transpose(sumX));
        double[][] p = MatrixOps.fillMatrix(n, n, 0.0d);
        // Only a length-n vector is needed; the old code allocated an n x n matrix to obtain it.
        double[] beta = new double[n];
        Arrays.fill(beta, 1.0d);
        double logU = Math.log(perplexity);
        System.out.println("Starting x2p...");

        for (int i = 0; i < n; i++) {
            if (i % 500 == 0) {
                System.out.println("Computing P-values for point " + i + " of " + n + "...");
            }
            double betaMin = Double.NEGATIVE_INFINITY;
            double betaMax = Double.POSITIVE_INFINITY;
            double[][] di = MatrixOps.getValuesFromRow(distances, i,
                    MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)));

            TSne.R hbeta = Hbeta(di, beta[i]);
            double[][] thisP = hbeta.P;
            double hDiff = hbeta.H - logU;
            int tries = 0;
            // Bisect on beta until the entropy is within tol of log(perplexity). The decompiled
            // version used while (true) with no break, which never terminated.
            while (Math.abs(hDiff) > tol && tries < MAX_BINARY_SEARCH_STEPS) {
                if (hDiff > 0.0d) {
                    // Entropy too high: sharpen the kernel by increasing beta.
                    betaMin = beta[i];
                    beta[i] = Double.isInfinite(betaMax) ? beta[i] * 2.0d : (beta[i] + betaMax) / 2.0d;
                } else {
                    // Entropy too low: widen the kernel by decreasing beta.
                    betaMax = beta[i];
                    beta[i] = Double.isInfinite(betaMin) ? beta[i] / 2.0d : (beta[i] + betaMin) / 2.0d;
                }
                hbeta = Hbeta(di, beta[i]);
                thisP = hbeta.P;
                hDiff = hbeta.H - logU;
                tries++;
            }
            MatrixOps.assignValuesToRow(p, i,
                    MatrixOps.concatenate(MatrixOps.range(0, i), MatrixOps.range(i + 1, n)), thisP[0]);
        }

        TSne.R r = new TSne.R();
        r.P = p;
        r.beta = beta;
        System.out.println("Mean value of sigma: " + MatrixOps.mean(MatrixOps.sqrt(MatrixOps.scalarInverse(beta))));
        return r;
    }

    /** Requests that a running {@link #tsne} stop before its next gradient-descent iteration. */
    @Override
    public void abort() {
        abort = true;
    }
}
