package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.neuralnetwork.LVQ;
import jsat.datatransform.DataTransform;
import jsat.distributions.Uniform;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * LargeViz embeds a data set into a low-dimensional space (two dimensions by
 * default) for visualization.
 * <p>
 * {@link #transform(DataSet, boolean)} first computes sparse neighbor
 * affinities in the source space with {@code TSNE.computeP}. It uses
 * {@code floor(3 * perplexity)} neighbors per point, clamped to the range
 * {@code [1, n - 1]}. It then refines a small random initial embedding with
 * stochastic gradient steps. Each step:
 * <ul>
 * <li>picks a random point;</li>
 * <li>samples one of that point's neighbors in proportion to the neighbor
 * affinities, and pulls the two together;</li>
 * <li>pushes away up to {@link #getNegativeSamples()} other points, drawn in
 * proportion to the 0.75 power of each point's total affinity. These negative
 * terms are scaled by {@link #getGamma()}.</li>
 * </ul>
 * When run in parallel, the embedding vectors are updated without any locking.
 */
public class LargeViz implements VisualizationTransform {

    /** Passes over the data; the total number of gradient steps is this times the data set size. */
    private static final int ITERATIONS_PER_POINT = 1000;

    /** Learning rate of the first gradient step; it then decays linearly toward {@link #MIN_LEARNING_RATE}. */
    private static final double INITIAL_LEARNING_RATE = 1.0d;

    /** Lower bound for the decaying learning rate. */
    private static final double MIN_LEARNING_RATE = 1.0E-4d;

    /**
     * Cap on consecutive rejected draws when picking one negative sample. It
     * keeps tiny data sets, where every candidate is rejected, from looping
     * forever.
     */
    private static final int MAX_NEGATIVE_SAMPLE_ATTEMPTS = 1000;

    /**
     * A neighbor of the updated point is rejected as a negative sample when
     * its cumulative (normalized) affinity is below this value.
     */
    private static final double NEGATIVE_NEIGHBOR_CDF_CUTOFF = 0.98d;

    /** Metric used by the neighbor/affinity computation in the source space. */
    private DistanceMetric dm_source = new EuclideanDistance();
    /** Metric used to measure distances between embedded points. */
    private DistanceMetric dm_embed = new EuclideanDistance();
    /** Perplexity of the affinity computation; also sets the neighbor count. */
    private double perplexity = 50.0d;
    /** Target (embedding) dimension. */
    private int dt = 2;
    /** Number of negative samples drawn per gradient step. */
    private int M = 5;
    /** Weight of the negative-sample gradient terms. */
    private double gamma = 7.0d;

    /**
     * Sets the perplexity. The number of neighbors used per point is
     * {@code floor(3 * perplexity)}, clamped to {@code [1, n - 1]}.
     *
     * @param d the perplexity; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setPerplexity(double d) {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new IllegalArgumentException("perplexity must be positive, not " + d);
        }
        this.perplexity = d;
    }

    /**
     * @return the perplexity used for the source-space affinities
     */
    public double getPerplexity() {
        return this.perplexity;
    }

    /**
     * Sets the metric used for neighbor search in the original feature space.
     *
     * @param distanceMetric the source-space metric
     */
    public void setDistanceMetricSource(DistanceMetric distanceMetric) {
        this.dm_source = distanceMetric;
    }

    /**
     * Sets the metric used to measure distances between embedded points.
     *
     * @param distanceMetric the embedding-space metric
     */
    public void setDistanceMetricEmbedding(DistanceMetric distanceMetric) {
        this.dm_embed = distanceMetric;
    }

    /**
     * Sets how many negative samples are drawn for each gradient step.
     *
     * @param i the number of negative samples; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setNegativeSamples(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of negative samples must be positive, not " + i);
        }
        this.M = i;
    }

    /**
     * @return the number of negative samples drawn per gradient step
     */
    public int getNegativeSamples() {
        return this.M;
    }

    /**
     * Sets the weight of the negative-sample gradient terms.
     *
     * @param d the weight; must be positive and finite
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    public void setGamma(double d) {
        if (Double.isInfinite(d) || Double.isNaN(d) || d <= 0.0d) {
            throw new IllegalArgumentException("Gamma must be positive, not " + d);
        }
        this.gamma = d;
    }

    /**
     * @return the weight of the negative-sample gradient terms
     */
    public double getGamma() {
        return this.gamma;
    }

    @Override
    public int getTargetDimension() {
        return this.dt;
    }

    /**
     * Sets the embedding dimension.
     *
     * @param i the requested dimension
     * @return {@code true} if accepted; {@code false}, leaving the dimension
     *         unchanged, when {@code i < 2}
     */
    @Override
    public boolean setTargetDimension(int i) {
        if (i < 2) {
            return false;
        }
        this.dt = i;
        return true;
    }

    /**
     * Computes a LargeViz embedding of {@code dataSet}.
     *
     * @param dataSet the data to embed; must contain at least two points
     * @param z       forwarded to {@code TSNE.computeP} and
     *                {@code ParallelUtils.run}; presumably requests parallel execution
     * @return a shallow clone of {@code dataSet}. Each point's numeric vector is
     *         replaced by its embedding; categorical values and weights are
     *         carried over.
     * @throws IllegalArgumentException if {@code dataSet} has fewer than two points
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> dataSet, boolean z) {
        final Random rand = RandomUtil.getRandom();
        final ThreadLocal<Random> localRand = ThreadLocal.withInitial(RandomUtil::getRandom);
        final int n = dataSet.getSampleSize();
        if (n < 2) {
            throw new IllegalArgumentException("LargeViz needs at least 2 data points, not " + n);
        }

        // Neighbors per point: floor(3 * perplexity), never more than the other n - 1
        // points and never fewer than 1. Zero neighbors would make every CDF empty.
        final int knn = Math.max(1, (int) Math.min(Math.floor(3.0d * this.perplexity), n - 1));
        final double[][] pij = new double[n][knn];
        final int[][] neighbors = new int[n][knn];
        TSNE.computeP(dataSet, z, rand, knn, neighbors, pij, this.dm_source, this.perplexity);

        // Per-row affinity totals drive both samplers. The MIN_VALUE term keeps a row
        // of all-zero affinities from producing a zero total.
        final double[] rowTotals = new double[n];
        for (int i = 0; i < n; i++) {
            rowTotals[i] = DenseVector.toDenseVec(pij[i]).sum() + pij[i].length * Double.MIN_VALUE;
        }
        final double[][] neighborCdf = buildNeighborCdfs(pij, rowTotals);
        final double[] negativeCdf = buildNegativeSamplingCdf(rowTotals);
        final List<Vec> embedded = initialEmbedding(n, rand);

        // long arithmetic: ITERATIONS_PER_POINT * n overflows int for n above ~2.1 million.
        final long totalSteps = (long) ITERATIONS_PER_POINT * n;
        final ThreadLocal<Vec> localGrad = ThreadLocal.withInitial(() -> new DenseVector(this.dt));
        final ThreadLocal<Vec> localPos = ThreadLocal.withInitial(() -> new DenseVector(this.dt));
        final ThreadLocal<Vec> localNeg = ThreadLocal.withInitial(() -> new DenseVector(this.dt));
        final AtomicLong stepCounter = new AtomicLong();

        ParallelUtils.run(z, n, (start, end) -> {
            Random r = localRand.get();
            Vec grad = localGrad.get();
            Vec pos = localPos.get();
            Vec neg = localNeg.get();
            for (int iter = 0; iter < ITERATIONS_PER_POINT; iter++) {
                // gradientStep picks its point at random; this range only decides how many
                // steps this chunk contributes per pass.
                for (int step = start; step < end; step++) {
                    // Floating-point division: long / long truncates to 0 and would freeze the rate.
                    double progress = stepCounter.getAndIncrement() / (double) totalSteps;
                    double eta = Math.max(INITIAL_LEARNING_RATE * (1.0d - progress), MIN_LEARNING_RATE);
                    gradientStep(r, eta, neighbors, neighborCdf, negativeCdf, embedded, grad, pos, neg);
                }
            }
        });

        DataSet<Type> result = dataSet.shallowClone2();
        // Identity (not equals) lookup: maps each original DataPoint object to its row.
        final IdentityHashMap<DataPoint, Integer> indexOf = new IdentityHashMap<>(n);
        for (int i = 0; i < n; i++) {
            indexOf.put(dataSet.getDataPoint(i), i);
        }
        result.applyTransform(new DataTransform() {
            @Override
            public DataPoint transform(DataPoint dataPoint) {
                Vec embedding = embedded.get(indexOf.get(dataPoint));
                return new DataPoint(embedding, dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
            }

            @Override
            public void fit(DataSet dataSet2) {
                // Nothing to fit: the embedding was computed above.
            }

            @Override
            public DataTransform clone() {
                // Returns this instance rather than a copy.
                return this;
            }
        });

        // The generic Type is erased at runtime; the clone is the same data set type as the input.
        @SuppressWarnings("unchecked")
        final Type typed = (Type) result;
        return typed;
    }

    /**
     * Performs one stochastic update of the embedding. It samples a point and
     * one of its neighbors, then pulls them together. It also pushes the point
     * away from up to {@link #M} negative samples.
     *
     * @param r           this thread's random source
     * @param eta         current learning rate
     * @param neighbors   neighbor indices per point
     * @param neighborCdf normalized cumulative affinities per point, aligned with {@code neighbors}
     * @param negativeCdf normalized cumulative distribution for negative sampling
     * @param embedded    the embedding, updated in place
     * @param grad        scratch vector for the first point's accumulated gradient
     * @param pos         scratch vector for the positive (neighbor) term
     * @param neg         scratch vector for one negative term
     */
    private void gradientStep(Random r, double eta, int[][] neighbors, double[][] neighborCdf,
            double[] negativeCdf, List<Vec> embedded, Vec grad, Vec pos, Vec neg) {
        final int i = r.nextInt(embedded.size());
        final int j = neighbors[i][sampleNeighborIndex(neighborCdf[i], r)];
        final Vec yi = embedded.get(i);
        final Vec yj = embedded.get(j);
        final double dist = this.dm_embed.dist(i, j, embedded, (List<Double>) null);
        if (!(dist > 0.0d)) {
            return; // coincident points give no usable direction
        }

        // NOTE(review): both gradient scales below differ by a factor of dist from
        // the plain derivatives of log(1/(1+d^2)) and log(1 - 1/(1+d^2)). Kept as-is;
        // confirm whether this is intended.
        yi.copyTo(pos);
        pos.mutableSubtract(yj);
        pos.mutableMultiply((-2.0d * dist) / (dist * dist + 1.0d));
        pos.copyTo(grad);

        for (int m = 0; m < this.M; m++) {
            final int k = sampleNegative(r, i, j, neighbors[i], neighborCdf[i], negativeCdf);
            if (k < 0) {
                continue; // every draw was rejected; skip this negative sample
            }
            final Vec yk = embedded.get(k);
            final double distK = this.dm_embed.dist(i, k, embedded, (List<Double>) null);
            if (distK >= 1.0E-12d) {
                yi.copyTo(neg);
                neg.mutableSubtract(yk);
                neg.mutableMultiply((2.0d * this.gamma) / (distK * (distK * distK + 1.0d)));
                grad.mutableAdd(neg);
                yk.mutableSubtract(eta, neg);
            }
        }
        yi.mutableAdd(eta, grad);
        yj.mutableAdd(-eta, pos);
    }

    /**
     * Builds, for each point, the normalized cumulative distribution of its
     * neighbor affinities.
     *
     * @param pij       affinities per point
     * @param rowTotals the per-row normalizers
     * @return one CDF per row, with the same length as that row
     */
    private static double[][] buildNeighborCdfs(double[][] pij, double[] rowTotals) {
        final double[][] cdf = new double[pij.length][];
        for (int i = 0; i < pij.length; i++) {
            final double[] p = pij[i];
            final double[] c = new double[p.length];
            if (c.length > 0) {
                c[0] = p[0];
                for (int k = 1; k < c.length; k++) {
                    // The extra ulp nudges each step upward, presumably to avoid flat runs.
                    c[k] = Math.ulp(p[k]) + p[k] + c[k - 1];
                }
                // Normalize every entry, including the first, so all entries share one scale.
                for (int k = 0; k < c.length; k++) {
                    c[k] /= rowTotals[i];
                }
            }
            cdf[i] = c;
        }
        return cdf;
    }

    /**
     * Builds the cumulative distribution used to draw negative samples. Each
     * point's weight is its affinity total raised to the 0.75 power.
     *
     * @param rowTotals per-point affinity totals
     * @return a non-decreasing array whose last entry is 1 for finite weights
     */
    private static double[] buildNegativeSamplingCdf(double[] rowTotals) {
        final double[] cdf = new double[rowTotals.length];
        double total = 0.0d;
        for (int i = 0; i < cdf.length; i++) {
            final double w = Math.pow(rowTotals[i], 0.75d);
            total += w;
            cdf[i] = i > 0 ? w + cdf[i - 1] : w;
        }
        for (int i = 0; i < cdf.length; i++) {
            cdf[i] /= total;
        }
        return cdf;
    }

    /** Returns the index of {@code u}, or the position where it would be inserted, in the sorted {@code cdf}. */
    private static int insertionPoint(double[] cdf, double u) {
        final int idx = Arrays.binarySearch(cdf, u);
        return idx < 0 ? -idx - 1 : idx;
    }

    /**
     * Draws a neighbor slot according to {@code cdf}. Falls back to a uniform
     * slot if the draw lands past the last entry.
     */
    private static int sampleNeighborIndex(double[] cdf, Random r) {
        int idx = insertionPoint(cdf, r.nextDouble());
        if (idx >= cdf.length) {
            idx = r.nextInt(cdf.length);
        }
        return idx;
    }

    /**
     * Draws a negative sample for point {@code i} whose positive partner is {@code j}.
     * The sample is rejected and redrawn when it is {@code i}, {@code j}, or a
     * neighbor of {@code i} that falls below {@link #NEGATIVE_NEIGHBOR_CDF_CUTOFF}.
     *
     * @return the sampled index, or -1 if {@link #MAX_NEGATIVE_SAMPLE_ATTEMPTS}
     *         draws were all rejected
     */
    private static int sampleNegative(Random r, int i, int j, int[] neighborsOfI, double[] cdfOfI,
            double[] negativeCdf) {
        for (int attempt = 0; attempt < MAX_NEGATIVE_SAMPLE_ATTEMPTS; attempt++) {
            final int k = Math.min(insertionPoint(negativeCdf, r.nextDouble()), negativeCdf.length - 1);
            if (k != i && k != j && !isExcludedNeighbor(k, neighborsOfI, cdfOfI)) {
                return k;
            }
        }
        return -1;
    }

    /** True if {@code k} is a neighbor whose cumulative affinity is below the cutoff. */
    private static boolean isExcludedNeighbor(int k, int[] neighborsOfI, double[] cdfOfI) {
        for (int s = 0; s < neighborsOfI.length; s++) {
            if (neighborsOfI[s] == k && cdfOfI[s] < NEGATIVE_NEIGHBOR_CDF_CUTOFF) {
                return true;
            }
        }
        return false;
    }

    /**
     * Creates {@code n} embedding vectors of dimension {@link #dt}. Each
     * coordinate is sampled uniformly from {@code [-5e-5/dt, 5e-5/dt]}.
     */
    private List<Vec> initialEmbedding(int n, Random rand) {
        final Uniform init = new Uniform(-5.0E-5d / this.dt, 5.0E-5d / this.dt);
        final List<Vec> embedded = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            embedded.add(init.sampleVec(this.dt, rand));
        }
        return embedded;
    }
}
