/*
 * Decompiled with CFR 0.152.
 */
package smile.manifold;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.stat.distribution.GaussianDistribution;
import smile.util.MulticoreExecutor;

/**
 * t-distributed Stochastic Neighbor Embedding (t-SNE).
 *
 * <p>Input affinities {@code P} are Gaussian kernels over the pairwise
 * distance matrix (as produced by {@code Math.pdist}), with a per-point kernel
 * width found by binary search so that each conditional distribution has the
 * requested perplexity. Output affinities {@code Q} use the heavy-tailed
 * kernel {@code 1 / (1 + D_ij)} on the embedding. The embedding is optimized
 * by gradient descent with momentum and per-coordinate adaptive gains;
 * {@code P} is exaggerated by a factor of 12 until iteration
 * {@code momentumSwitchIter}, at which point the momentum is also raised to
 * {@code finalMomentum}.
 */
public class TSNE {
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** Factor by which P is scaled during the early-exaggeration phase. */
    private static final double EXAGGERATION = 12.0;

    /** Current embedding, one row of length d per input sample. */
    private double[][] coordinates;
    /** Learning rate. */
    private double eta = 500.0;
    /** Current momentum; starts at 0.5 and is replaced by finalMomentum. */
    private double momentum = 0.5;
    /** Momentum used from momentumSwitchIter on. */
    private double finalMomentum = 0.8;
    /** Iteration at which momentum switches and the exaggeration of P is removed. */
    private int momentumSwitchIter = 250;
    /** Lower bound for the adaptive gains. */
    private double minGain = 0.01;
    /** 1-based iteration counter, carried across calls to learn(). */
    private int totalIter = 1;
    /** Last update step of each coordinate (the momentum term). */
    private double[][] dY;
    /** Per-coordinate adaptive learning-rate gains. */
    private double[][] gains;
    /** Symmetric input affinities (possibly exaggerated). */
    private double[][] P;
    /** Unnormalized output affinities 1 / (1 + D_ij) of the embedding. */
    private double[][] Q;
    /** Sum of the off-diagonal entries of Q, i.e. the normalizer of Q. */
    private double Qsum;

    /**
     * Computes a t-SNE embedding with perplexity 20, learning rate 200 and
     * 1000 iterations.
     *
     * @param X input data (one row per sample), or a square matrix which is
     *          taken as a precomputed distance matrix.
     * @param d dimension of the embedding space.
     */
    public TSNE(double[][] X, int d) {
        this(X, d, 20.0, 200.0, 1000);
    }

    /**
     * Computes a t-SNE embedding.
     *
     * @param X          input data (one row per sample), or a square matrix which
     *                   is taken as a precomputed distance matrix in the same
     *                   form {@code Math.pdist(X, D, true, false)} produces.
     * @param d          dimension of the embedding space.
     * @param perplexity target perplexity of each conditional distribution.
     * @param eta        learning rate.
     * @param iterations number of gradient-descent iterations to run.
     * @throws IllegalArgumentException if X is empty or d is less than 1.
     */
    public TSNE(double[][] X, int d, double perplexity, double eta, int iterations) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Empty data set");
        }
        if (d < 1) {
            throw new IllegalArgumentException("Invalid embedding dimension: " + d);
        }

        this.eta = eta;
        int n = X.length;

        // A square input is taken as-is as the distance matrix. It is a local,
        // not a field, so the n x n matrix is not retained after construction.
        double[][] D;
        if (n == X[0].length) {
            D = X;
        } else {
            D = new double[n][n];
            Math.pdist(X, D, true, false);
        }

        double[][] Y = new double[n][d];
        coordinates = Y;
        dY = new double[n][d];
        gains = new double[n][d];

        // Random initial embedding drawn from a narrow Gaussian around 0.
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; i++) {
            Arrays.fill(gains[i], 1.0);
            double[] Yi = Y[i];
            for (int j = 0; j < d; j++) {
                Yi[j] = gaussian.rand();
            }
        }

        // A loose tolerance is enough: small errors in kernel width barely matter.
        P = expd(D, perplexity, 0.001);
        Q = new double[n][n];

        // Symmetrize and exaggerate P. Every row of P is normalized to 1, so
        // sum(P + P^T) = 2n. Tiny or NaN values are clamped to keep log(p) finite.
        double Psum = 2.0 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = EXAGGERATION * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < 1.0E-16) {
                    p = 1.0E-16;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }

        learn(iterations);
    }

    /**
     * Runs additional gradient-descent iterations on the current embedding and
     * then shifts the embedding to zero mean.
     *
     * @param iterations number of iterations to run.
     */
    public void learn(int iterations) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;

        // One contiguous range of rows per worker; the last range takes the remainder.
        int nprocs = MulticoreExecutor.getThreadPoolSize();
        int chunk = n / nprocs;
        ArrayList<SNETask> tasks = new ArrayList<SNETask>();
        for (int i = 0; i < nprocs; i++) {
            int start = i * chunk;
            int end = i == nprocs - 1 ? n : (i + 1) * chunk;
            tasks.add(new SNETask(start, end));
        }

        for (int iter = 1; iter <= iterations; iter++, totalIter++) {
            // Output affinities of the current embedding.
            Math.pdist(Y, Q, true, false);
            Qsum = 0.0;
            for (int i = 0; i < n; i++) {
                double[] Qi = Q[i];
                for (int j = 0; j < i; j++) {
                    double q = 1.0 / (1.0 + Qi[j]);
                    Qi[j] = q;
                    Q[j][i] = q;
                    Qsum += q;
                }
            }
            Qsum *= 2.0;

            // Phase 1: gradients, gains and new steps for every row. The tasks
            // only read Y, so all gradients are evaluated at the embedding Q was
            // built from, and concurrent tasks never see partially updated rows.
            try {
                MulticoreExecutor.run(tasks);
            } catch (Exception e) {
                logger.error("t-SNE iteration task fails: {}", e);
            }

            // Phase 2: apply the new steps once all gradients are known.
            for (int i = 0; i < n; i++) {
                double[] Yi = Y[i];
                double[] dYi = dY[i];
                for (int k = 0; k < d; k++) {
                    Yi[k] += dYi[k];
                }
            }

            if (totalIter == momentumSwitchIter) {
                // End of the early-exaggeration phase.
                momentum = finalMomentum;
                for (int i = 0; i < n; i++) {
                    double[] Pi = P[i];
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= EXAGGERATION;
                    }
                }
            }

            if (iter % 50 == 0) {
                // KL divergence (bits) over the lower triangle, doubled by symmetry.
                double C = 0.0;
                for (int i = 0; i < n; i++) {
                    double[] Pi = P[i];
                    double[] Qi = Q[i];
                    for (int j = 0; j < i; j++) {
                        double p = Pi[j];
                        double q = Qi[j] / Qsum;
                        if (Double.isNaN(q) || q < 1.0E-16) {
                            q = 1.0E-16;
                        }
                        C += p * Math.log2(p / q);
                    }
                }
                logger.info("Error after {} iterations: {}", totalIter, 2.0 * C);
            }
        }

        // Shift the embedding to zero mean.
        double[] colMeans = Math.colMeans(Y);
        for (double[] Yi : Y) {
            for (int j = 0; j < d; j++) {
                Yi[j] -= colMeans[j];
            }
        }
    }

    /**
     * Computes the row-normalized Gaussian affinities of each point, searching
     * each point's kernel precision for the requested perplexity in parallel.
     *
     * @param D          pairwise distance matrix.
     * @param perplexity target perplexity.
     * @param tol        tolerance on the entropy difference (nats).
     * @return an n x n matrix whose rows are conditional distributions.
     */
    private double[][] expd(double[][] D, double perplexity, double tol) {
        int n = D.length;
        double[][] P = new double[n][n];
        double[] DiSum = Math.rowSums(D);

        int nprocs = MulticoreExecutor.getThreadPoolSize();
        int chunk = n / nprocs;
        ArrayList<PerplexityTask> tasks = new ArrayList<PerplexityTask>();
        for (int i = 0; i < nprocs; i++) {
            int start = i * chunk;
            int end = i == nprocs - 1 ? n : (i + 1) * chunk;
            tasks.add(new PerplexityTask(start, end, D, P, DiSum, perplexity, tol));
        }

        try {
            MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            logger.error("t-SNE Gaussian kernel width search task fails: {}", e);
        }
        return P;
    }

    /**
     * Returns the embedding. This is the internal array itself, not a copy;
     * it is modified in place by {@link #learn(int)}.
     *
     * @return the n x d embedding coordinates.
     */
    public double[][] getCoordinates() {
        return coordinates;
    }

    /**
     * Binary search of the Gaussian kernel precision (beta) for a range of
     * rows, writing each normalized row of affinities into P. Each task writes
     * only its own rows of P.
     */
    private static class PerplexityTask implements Callable<Void> {
        private final int start;
        private final int end;
        private final double[][] D;
        private final double[][] P;
        private final double[] DiSum;
        private final double perplexity;
        private final double tol;

        PerplexityTask(int start, int end, double[][] D, double[][] P, double[] DiSum, double perplexity, double tol) {
            this.start = start;
            this.end = end;
            this.D = D;
            this.P = P;
            this.DiSum = DiSum;
            this.perplexity = perplexity;
            this.tol = tol;
        }

        @Override
        public Void call() {
            for (int i = start; i < end; i++) {
                compute(i);
            }
            return null;
        }

        /** Fills row i of P with the affinities at the perplexity-matching beta. */
        private void compute(int i) {
            int n = D.length;
            // Target entropy in nats (perplexity = exp(entropy)). The entropy
            // below is also in nats, so both sides use the same base.
            double logU = java.lang.Math.log(perplexity);
            double[] Pi = P[i];
            double[] Di = D[i];
            double beta = Math.sqrt((double) (n - 1) / DiSum[i]);
            double betamin = 0.0;
            double betamax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Pisum = 0.0;
            double Hdiff = Double.MAX_VALUE;
            for (int iter = 0; Math.abs(Hdiff) > tol && iter < 50; iter++) {
                // Unnormalized affinities to every other point, plus the running
                // sum of p * (beta * D) needed for the entropy. The point itself
                // is skipped, so its diagonal distance does not need to be zero.
                Pisum = 0.0;
                double H = 0.0;
                for (int j = 0; j < n; j++) {
                    if (j == i) {
                        Pi[j] = 0.0;
                        continue;
                    }
                    double dist = beta * Di[j];
                    double p = Math.exp(-dist);
                    Pi[j] = p;
                    Pisum += p;
                    H += p * dist;
                }

                // Shannon entropy (nats) of the normalized row: ln(sum) + E[beta * D].
                H = java.lang.Math.log(Pisum) + H / Pisum;
                Hdiff = H - logU;

                if (Math.abs(Hdiff) > tol) {
                    if (Hdiff > 0.0) {
                        // Entropy too high: sharpen the kernel.
                        betamin = beta;
                        beta = Double.isInfinite(betamax) ? beta * 2.0 : (beta + betamax) / 2.0;
                    } else {
                        // Entropy too low: widen the kernel.
                        betamax = beta;
                        beta = (beta + betamin) / 2.0;
                    }
                }
                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", Hdiff, i, beta, H, logU);
            }

            // Normalize even when the search stopped at its iteration cap, so that
            // every row is a distribution (the constructor assumes sum(P) = 2n).
            if (Pisum > 0.0) {
                for (int j = 0; j < n; j++) {
                    Pi[j] /= Pisum;
                }
            }
        }
    }

    /**
     * Computes, for a range of rows, the gradient of the KL cost, the adapted
     * gains and the new momentum step. It writes only dY and gains of its own
     * rows and never modifies the coordinates; learn() applies the steps
     * after all tasks have finished.
     */
    private class SNETask implements Callable<Void> {
        private final int start;
        private final int end;
        /** Gradient buffer for the row being processed (reused across rows). */
        private final double[] dC;

        SNETask(int start, int end) {
            this.start = start;
            this.end = end;
            this.dC = new double[coordinates[0].length];
        }

        @Override
        public Void call() {
            for (int i = start; i < end; i++) {
                compute(i);
            }
            return null;
        }

        /** Updates gains[i] and dY[i] from the gradient at the current embedding. */
        private void compute(int i) {
            double[][] Y = coordinates;
            int n = Y.length;
            int d = Y[0].length;
            Arrays.fill(dC, 0.0);
            double[] Yi = Y[i];
            double[] Pi = P[i];
            double[] Qi = Q[i];
            double[] dYi = dY[i];
            double[] g = gains[i];

            // dC/dy_i = 4 * sum_j (p_ij - q_ij) * (y_i - y_j) / (1 + D_ij)
            for (int j = 0; j < n; j++) {
                if (i == j) {
                    continue;
                }
                double[] Yj = Y[j];
                double q = Qi[j];
                double z = (Pi[j] - q / Qsum) * q;
                for (int k = 0; k < d; k++) {
                    dC[k] += 4.0 * (Yi[k] - Yj[k]) * z;
                }
            }

            for (int k = 0; k < d; k++) {
                // Grow the gain when the gradient disagrees in sign with the last
                // step, shrink it otherwise.
                g[k] = Math.signum(dC[k]) != Math.signum(dYi[k]) ? g[k] + 0.2 : g[k] * 0.8;
                if (g[k] < minGain) {
                    g[k] = minGain;
                }
                // New step; it is added to Y by learn() after all tasks complete.
                dYi[k] = momentum * dYi[k] - eta * g[k] * dC[k];
            }
        }
    }
}

