package com.jujutsu.tsne.demos;

import com.itextpdf.text.xml.xmp.DublinCoreSchema;
import com.jujutsu.tsne.MemOptimizedTSne;
import com.jujutsu.utils.MatrixOps;
import com.jujutsu.utils.MatrixUtils;
import com.jujutsu.utils.TSneUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import joinery.DataFrame;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.apache.commons.cli.PosixParser;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.math.plot.FrameView;
import org.math.plot.Plot2DPanel;
import org.math.plot.PlotPanel;
import org.math.plot.plotObjects.Base;
import org.math.plot.plots.ColoredScatterPlot;
import org.math.plot.plots.ScatterPlot;
import org.springframework.beans.propertyeditors.StringArrayPropertyEditor;

/* loaded from: input_file:com/jujutsu/tsne/demos/TSneCsv.class */
/**
 * Command-line demo that loads a CSV file with joinery, optionally pre-processes it
 * (transpose, column drops, log scaling, normalization, noise, subsampling), runs
 * {@link MemOptimizedTSne} on the resulting numeric matrix and then saves and/or plots
 * the embedding.
 *
 * <p>All configuration lives in package-private static fields that
 * {@link #processCommandline(String[])} overwrites from the parsed options, so this class
 * is not safe for concurrent use.
 *
 * <p>NOTE(review): several string literals appear as constants from unrelated libraries
 * ({@code Base.LOGARITHM}, {@code DublinCoreSchema.DEFAULT_XPATH_ID},
 * {@code StringArrayPropertyEditor.DEFAULT_SEPARATOR},
 * {@code DefaultExpressionEngine.DEFAULT_INDEX_END}). This looks like a decompiler
 * substituting equal-valued constants for literals (presumably "log", "dc", "," and ")").
 * Confirm the values before replacing them with plain literals.
 */
public class TSneCsv {
    // PCA target dimensionality before t-SNE; negative means "no PCA".
    static int initial_dims = -1;
    // Embedding dimensionality: 2 or 3 (set via the output_dims option).
    static int output_dims = 2;
    static double perplexity = 20.0d;
    // True when a label column is read from the dataset itself.
    static boolean hasLabels = true;
    static boolean scale_log = false;
    static boolean normalize = false;
    static boolean addNoise = false;
    static boolean subSample = false;
    static boolean hasHeader = true;
    static boolean doPlot = true;
    static boolean doSave = false;
    // NOTE(review): never assigned from the command line, so it is currently always false.
    static boolean transpose_after = false;
    static String output_fn = null;
    static String label_fn = null;
    static String naString = null;
    static int iterations = 2000;
    // 0-based label column index (the command-line option is 1-based).
    static int label_col_no = 0;
    static String label_col_name = null;

    /** Usage line printed together with the option help. */
    private static final String USAGE = "TSneCsv [options] <csv file>";

    /** Expected number of rows kept when the subsample option is given. */
    private static final double SUBSAMPLE_TARGET_ROWS = 2500.0d;

    /** Column names for the embedding coordinates in the saved CSV. */
    private static final String[] AXIS_NAMES = {"X", "Y", "Z"};

    /**
     * Parses the command line, updates the static configuration fields and loads and
     * pre-processes the CSV file named by the first positional argument.
     *
     * <p>Exits the JVM (after printing help) when the command line cannot be parsed, no
     * CSV file is given, or an unsupported separator is requested.
     *
     * @param args raw command-line arguments
     * @return the loaded data after optional transpose, column drops and subsampling
     * @throws IOException if the CSV file cannot be read
     * @throws IllegalArgumentException if {@code output_dims} is not 2D/3D or
     *     {@code label_column_no} is smaller than 1
     * @throws NumberFormatException if a numeric option value cannot be parsed
     */
    public static DataFrame<Object> processCommandline(String[] args) throws IOException {
        Options options = buildOptions();
        HelpFormatter helpFormatter = new HelpFormatter();

        CommandLine commandLine = null;
        try {
            commandLine = new PosixParser().parse(options, args);
        } catch (ParseException e) {
            System.out.println("TSneCsv: Could not parse command line due to:  " + e.getMessage());
            System.out.println("Args were:");
            for (String arg : args) {
                System.out.print(arg + ", ");
            }
            // Terminate the echoed argument list so the help text starts on its own line.
            System.out.println();
            helpFormatter.printHelp(USAGE, options);
            System.exit(-1);
        }
        if (commandLine.getArgs().length == 0) {
            System.out.println("No CSV file given...");
            helpFormatter.printHelp(USAGE, options);
            System.exit(255);
        }

        String separator = commandLine.hasOption("separator")
                ? commandLine.getOptionValue("separator").trim()
                : StringArrayPropertyEditor.DEFAULT_SEPARATOR;
        // "\\t" is the two-character sequence backslash-t as typed on the command line.
        if (!separator.equals(StringArrayPropertyEditor.DEFAULT_SEPARATOR)
                && !separator.equals(";")
                && !separator.equals("\\t")) {
            System.out.println("Only the separators ',' , ';' or '\\t' is currently supported...");
            helpFormatter.printHelp(USAGE, options);
            System.exit(255);
        }

        String csvFile = commandLine.getArgs()[0];
        if (commandLine.hasOption("iterations")) {
            iterations = Integer.parseInt(commandLine.getOptionValue("iterations").trim());
        }
        if (commandLine.hasOption("label_column_no")) {
            int oneBased = Integer.parseInt(commandLine.getOptionValue("label_column_no").trim());
            if (oneBased < 1) {
                throw new IllegalArgumentException(
                        "label_column_no is 1-based and must be >= 1, got: " + oneBased);
            }
            label_col_no = oneBased - 1;
        }
        if (commandLine.hasOption("label_column_name")) {
            label_col_name = commandLine.getOptionValue("label_column_name");
        }
        if (commandLine.hasOption("no_headers")) {
            hasHeader = false;
        }
        if (commandLine.hasOption("na_string")) {
            naString = commandLine.getOptionValue("na_string").trim();
        }
        System.out.println("TSneCsv: Running " + iterations + " iterations of t-SNE on " + csvFile);
        System.out.println("NA string is: " + naString);

        DataFrame.NumberDefault numberDefault = commandLine.hasOption("double_default")
                ? DataFrame.NumberDefault.DOUBLE_DEFAULT
                : DataFrame.NumberDefault.LONG_DEFAULT;
        // Keep the frame exactly as loaded: the "show" option displays the raw input,
        // not the transposed/dropped/subsampled result.
        final DataFrame<Object> loaded =
                DataFrame.readCsv(csvFile, separator, numberDefault, naString, hasHeader);
        DataFrame<Object> data = loaded;
        System.out.println("Loaded CSV with: " + data.length() + " rows and " + data.size() + " columns.");

        if (commandLine.hasOption("transpose")) {
            data = data.transpose();
        }
        if (commandLine.hasOption("perplexity")) {
            perplexity = Double.parseDouble(commandLine.getOptionValue("perplexity").trim());
        }
        if (commandLine.hasOption("initial_dims")) {
            initial_dims = Integer.parseInt(commandLine.getOptionValue("initial_dims").trim());
        }
        if (commandLine.hasOption("no_labels")) {
            hasLabels = false;
        }
        if (commandLine.hasOption("label_file")) {
            label_fn = commandLine.getOptionValue("label_file").trim();
        }
        if (commandLine.hasOption("scale_log")) {
            System.out.println("Log transforming dataset...");
            scale_log = true;
        }
        if (commandLine.hasOption("normalize")) {
            System.out.println("Normalizing dataset...");
            normalize = true;
        }
        if (commandLine.hasOption("noise")) {
            System.out.println("Adding noise...");
            addNoise = true;
        }
        if (commandLine.hasOption("subsample")) {
            System.out.println("Subsampling dataset...");
            subSample = true;
        }
        if (commandLine.hasOption("output_dims")) {
            String dims = commandLine.getOptionValue("output_dims").trim();
            if (dims.equalsIgnoreCase("2d")) {
                output_dims = 2;
            } else if (dims.equalsIgnoreCase("3d")) {
                output_dims = 3;
            } else {
                throw new IllegalArgumentException("Only legal output_dims options are '2D' or '3D'.");
            }
        }
        // NOTE(review): the "transpose_after" option is declared in buildOptions() but never
        // read here, so transpose_after stays false. Confirm the intended semantics before wiring it.

        if (commandLine.hasOption("drop_column")) {
            List<Integer> dropIndices = new ArrayList<Integer>();
            for (String token : commandLine.getOptionValue("drop_column").split(StringArrayPropertyEditor.DEFAULT_SEPARATOR)) {
                String trimmed = token.trim();
                if (trimmed.length() > 0) {
                    dropIndices.add(Integer.valueOf(trimmed));
                }
            }
            if (!dropIndices.isEmpty()) {
                data = data.drop(dropIndices.toArray(new Integer[0]));
            }
        }
        // Name-based drops run after index-based drops, as the option help states.
        if (commandLine.hasOption("drop_name")) {
            for (String token : commandLine.getOptionValue("drop_name").split(StringArrayPropertyEditor.DEFAULT_SEPARATOR)) {
                String trimmed = token.trim();
                if (trimmed.length() > 0) {
                    data = data.drop(trimmed);
                }
            }
        }
        if (commandLine.hasOption("no_plot")) {
            doPlot = false;
        }
        if (commandLine.hasOption("output_file")) {
            output_fn = commandLine.getOptionValue("output_file");
            doSave = true;
        }

        if (subSample) {
            // Keep each row independently, so that about SUBSAMPLE_TARGET_ROWS rows survive.
            final double keepProbability = Math.min(1.0d, SUBSAMPLE_TARGET_ROWS / data.length());
            data = data.select(new DataFrame.Predicate<Object>() {
                private final Random rnd = new Random();

                @Override
                public Boolean apply(List<Object> row) {
                    return Boolean.valueOf(rnd.nextDouble() < keepProbability);
                }
            });
        }

        if (commandLine.hasOption("show")) {
            loaded.show();
        }
        return data;
    }

    /**
     * Builds the supported command-line options. Help texts embed the current default
     * values of the static configuration fields.
     */
    private static Options buildOptions() {
        Options options = new Options();
        options.addOption("perp", "perplexity", true, "set the perplexity of the t-SNE algorithm (default " + perplexity + DefaultExpressionEngine.DEFAULT_INDEX_END);
        options.addOption("idims", "initial_dims", true, "scale the dataset to initial dims with PCA before running t-SNE (default " + initial_dims + "). Negative number indicates no scaling");
        options.addOption("nolbls", "no_labels", false, "The dataset does not contain any labels (if not set labels are assumed to be in the first column)");
        options.addOption("lblf", "label_file", true, "Separate input file with dataset labels, one label per row, must contain at least as many rows as in the dataset. Extra labels will be thrown away");
        options.addOption("na", "na_string", true, "Dataset can contain N/A, the given string is parsed as N/A in the dataset");
        options.addOption("nohdr", "no_headers", false, "If set, won't try to read a first row of column headers / names");
        options.addOption(Base.LOGARITHM, "scale_log", false, "Scale the dataset by first taking the log of each datapoint (keeping zeros) ");
        options.addOption("norm", "normalize", false, "Normalize the data by subtracting the mean and dividing by the stdev (this is done after eventual log) ");
        options.addOption("iter", "iterations", true, "How many iterations to run, default is " + iterations);
        options.addOption("noplt", "no_plot", false, "Don't plot the resulting dataset ");
        options.addOption("lblcolno", "label_column_no", true, "If labels are not in first column, this option gives the (1-based) index of the label column");
        options.addOption("lblcolnme", "label_column_name", true, "If labels are not in first column, this option gives the name of the label column. Requires headers in the dataset");
        options.addOption("shw", "show", false, "Show displays the tabular data of a data frame in a gui window ");
        options.addOption("dn", "drop_name", true, "drop column names. Takes a list of names (Example: \"Customer Name,Comment,Id\") representing the column names to drop. This is done AFTER any drop_column!");
        options.addOption(DublinCoreSchema.DEFAULT_XPATH_ID, "drop_column", true, "drop column no's. Takes a list of integers (Example: \"1,2,8,11\") representing the columns to drop");
        options.addOption("sep", "separator", true, "column separator ',' , ';' , '\\t' (',' per default). '\\t' denotes tab");
        options.addOption("dbl", "double_default", false, "use Double as number format (Long is default but even with Long default, numbers with decimals will still be parsed as Double)");
        options.addOption("trsp", "transpose", false, "transpose the dataset first");
        options.addOption("trspa", "transpose_after", false, "transpose the dataset after t-SNE is done");
        options.addOption("out", "output_file", true, "Save the result to the given filename");
        options.addOption("no", "noise", false, "add a small amount of noise to each column. This can be useful with highly structured datasets which can otherwise cause problems");
        options.addOption("ss", "subsample", false, "the current implementation does not handle very large datasets due to memory and time constraints. Adding this flag will uniformly subsample the dataset");
        options.addOption("odim", "output_dims", true, "Alternatives are '2D' or '3D' (default is " + output_dims + "D)");
        return options;
    }

    /**
     * Runs the full pipeline: parse options, load data, extract labels, build the numeric
     * matrix (with optional log/normalize/noise), run t-SNE, then save and/or plot.
     *
     * @param args raw command-line arguments
     * @throws IOException if the CSV, label file or output file cannot be read/written
     * @throws IllegalArgumentException if fewer labels than data rows are available
     */
    static void tsne_csv(String[] args) throws IOException {
        DataFrame<Object> data = processCommandline(args);
        System.out.println("Dataset types:" + data.types());
        System.out.println(data.head(10));

        String[] labels = null;
        if (hasLabels) {
            List<Object> labelColumn;
            if (label_col_name != null) {
                System.out.println("Using labels from column name: " + label_col_name);
                labelColumn = data.col(label_col_name);
            } else {
                System.out.println("Using labels from column index: " + (label_col_no + 1));
                labelColumn = data.col(Integer.valueOf(label_col_no));
            }
            labels = new String[data.length()];
            int row = 0;
            for (Object value : labelColumn) {
                // N/A cells are null; String.valueOf avoids an NPE and labels them "null".
                labels[row++] = String.valueOf(value);
            }
            // The label column must not be fed to t-SNE as a feature.
            data = label_col_name != null ? data.drop(label_col_name) : data.drop(Integer.valueOf(label_col_no));
        }
        if (hasLabels && label_fn != null) {
            System.out.println("Ignoring label file " + label_fn
                    + " since labels are read from the dataset (use the no_labels option to use the label file)");
        }
        if (!hasLabels && label_fn != null) {
            System.out.println("Loading labels from:" + label_fn);
            // NOTE(review): with the subsample option, these labels are truncated by position and
            // will not line up with the randomly kept rows.
            String[] fileLabels = MatrixUtils.simpleReadLines(new File(label_fn));
            labels = fileLabels.length > data.length() ? Arrays.copyOf(fileLabels, data.length()) : fileLabels;
        }

        int expectedLabels = transpose_after ? data.size() : data.length();
        if (labels != null && labels.length < expectedLabels) {
            throw new IllegalArgumentException("The number of labels (" + labels.length + ") is not the same (or more) as the number of rows in the dataset (" + expectedLabels + ").");
        }

        // Non-numeric / N/A cells become 0.0 in the model matrix.
        double[][] modelMatrix = data.cast(Double.class).toModelMatrix(0.0d);
        if (scale_log) {
            modelMatrix = MatrixOps.log(modelMatrix, true);
        }
        if (normalize) {
            modelMatrix = MatrixOps.centerAndScale(modelMatrix);
        }
        if (addNoise) {
            modelMatrix = MatrixOps.addNoise(modelMatrix);
        }
        System.out.println(MatrixOps.doubleArrayToPrintString(modelMatrix, 5, 5, 20));

        MemOptimizedTSne tSne = new MemOptimizedTSne();
        long startMillis = System.currentTimeMillis();
        double[][] tsne = tSne.tsne(TSneUtils.buildConfig(modelMatrix, output_dims, initial_dims, perplexity, iterations));
        if (transpose_after) {
            // NOTE(review): dead branch (transpose_after is never set). It also replaces the
            // t-SNE result with the transposed *input* matrix; presumably transposeSerial(tsne)
            // was intended. Confirm before enabling the option.
            tsne = MatrixOps.transposeSerial(modelMatrix);
        }
        System.out.println("TSne took: " + ((System.currentTimeMillis() - startMillis) / 1000.0d) + " seconds");

        if (doSave) {
            saveResult(labels, tsne);
        }
        if (doPlot) {
            plot2D(labels, tsne);
        }
    }

    /**
     * Writes the embedding to {@link #output_fn} as CSV, one row per point. The columns are
     * an optional {@code label}, then one column per embedding dimension ({@code X},
     * {@code Y} and, for 3D output, {@code Z}).
     *
     * @param labels per-row labels, or {@code null} for no label column
     * @param embedding t-SNE output, one row per point
     * @throws IOException if the output file cannot be written
     */
    private static void saveResult(String[] labels, double[][] embedding) throws IOException {
        int dims = embedding.length > 0 ? embedding[0].length : output_dims;
        List<String> header = new ArrayList<String>(dims + 1);
        if (labels != null) {
            header.add("label");
        }
        for (int d = 0; d < dims; d++) {
            header.add(d < AXIS_NAMES.length ? AXIS_NAMES[d] : "D" + (d + 1));
        }

        DataFrame<Object> frame = new DataFrame<Object>(header.toArray(new String[0]));
        for (int row = 0; row < embedding.length; row++) {
            List<Object> values = new ArrayList<Object>(dims + 1);
            if (labels != null) {
                values.add(labels[row]);
            }
            for (int d = 0; d < dims; d++) {
                values.add(Double.valueOf(embedding[row][d]));
            }
            frame.append(values);
        }
        System.out.println(frame);
        frame.writeCsv(output_fn);
    }

    /**
     * Shows the embedding in a scatter-plot window. Points are coloured by label when
     * labels are given. Closing the window exits the JVM.
     *
     * <p>NOTE(review): this always uses a 2D panel; how 3D output is rendered depends on
     * jmathplot. Presumably only the first two coordinates are drawn; confirm.
     *
     * @param labels per-point labels, or {@code null}
     * @param points embedding, one row per point
     */
    static void plot2D(String[] labels, double[][] points) {
        Plot2DPanel panel = new Plot2DPanel();
        if (labels != null) {
            panel.plotCanvas.addPlot(new ColoredScatterPlot("TSne Result", points, labels));
        } else {
            panel.plotCanvas.addPlot(new ScatterPlot("Data", PlotPanel.COLORLIST[0], points));
        }
        panel.plotCanvas.setNotable(true);
        panel.plotCanvas.setNoteCoords(true);
        FrameView frameView = new FrameView(panel);
        frameView.setDefaultCloseOperation(javax.swing.WindowConstants.EXIT_ON_CLOSE);
        frameView.setVisible(true);
    }

    /**
     * Prints a matrix to stdout, one row per line, with values separated by ", ".
     * Rows may have different lengths.
     *
     * @param matrix the matrix to print
     */
    public static void printMtx(double[][] matrix) {
        for (double[] row : matrix) {
            for (double value : row) {
                System.out.print(value + ", ");
            }
            System.out.println();
        }
    }

    /**
     * Prints each element as {@code <runtime class>=><value>, } on a single line.
     * Null elements print as {@code null=>null}.
     *
     * @param values the elements to print
     */
    public static void printMtx(Object[] values) {
        for (Object value : values) {
            Object type = value == null ? null : value.getClass();
            System.out.print(type + "=>" + value + ", ");
        }
        System.out.println();
    }

    /** Entry point; see {@link #processCommandline(String[])} for the supported options. */
    public static void main(String[] args) throws IOException {
        tsne_csv(args);
    }
}
