/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.datatransform.visualization.VisualizationTransform;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Multidimensional Scaling (MDS) visualization transform.
 *
 * <p>Given pairwise dissimilarities between points, this transform searches for a
 * low-dimensional embedding whose pairwise distances approximate those
 * dissimilarities. Embedded distances are measured with the embedding metric.
 *
 * <p>The search minimizes the raw stress {@code sum_{i<j} (d_ij(X) - delta_ij)^2}
 * using the unit-weight Guttman-transform (SMACOF) update
 * {@code X <- (1/N) * B(X) * X}. It starts from a uniformly random configuration
 * in {@code [0, 1)^targetDimension}.
 */
public class MDS
implements VisualizationTransform {
    /** Metric used to measure distances between points in the embedded (output) space. */
    private DistanceMetric embedMetric = new EuclideanDistance();
    /** Metric used to measure dissimilarities between the points of an input data set. */
    private DistanceMetric dm = new EuclideanDistance();
    /** Optimization stops once the absolute stress change between iterations is at or below this value. */
    private double tolerance = 0.001;
    /** Upper bound on the number of optimization iterations. */
    private int maxIterations = 300;
    /** Dimensionality of the produced embedding. */
    private int targetSize = 2;

    /**
     * Sets the convergence tolerance.
     *
     * @param tolerance the absolute change in stress between two iterations at or
     *                  below which optimization stops; must be finite and non-negative
     * @throws IllegalArgumentException if {@code tolerance} is negative, infinite or NaN
     */
    public void setTolerance(double tolerance) {
        if (tolerance < 0.0 || Double.isInfinite(tolerance) || Double.isNaN(tolerance)) {
            throw new IllegalArgumentException("tolerance must be a non-negative value, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /**
     * @return the convergence tolerance on the absolute change in stress
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets the maximum number of optimization iterations.
     *
     * @param maxIterations the iteration limit; {@code 0} returns the random initial configuration
     * @throws IllegalArgumentException if {@code maxIterations} is negative
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 0) {
            throw new IllegalArgumentException("maxIterations must be non-negative, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of optimization iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the metric used to measure distances in the embedded space.
     *
     * <p>The setting is per instance and does not affect other {@code MDS} objects.
     *
     * @param embedMetric the metric to use for embedded distances
     * @throws IllegalArgumentException if {@code embedMetric} is {@code null}
     */
    public void setEmbeddingMetric(DistanceMetric embedMetric) {
        if (embedMetric == null) {
            throw new IllegalArgumentException("embedMetric must not be null");
        }
        this.embedMetric = embedMetric;
    }

    /**
     * @return the metric used to measure distances in the embedded space
     */
    public DistanceMetric getEmbeddingMetric() {
        return this.embedMetric;
    }

    /**
     * Embeds a data set.
     *
     * <p>Pairwise dissimilarities are computed with the input metric and then passed to
     * {@link #transform(Matrix, boolean)}.
     *
     * @param d        the data set to embed
     * @param parallel whether to use multiple threads
     * @return a shallow clone of {@code d} whose numeric features are replaced by the embedding
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d, boolean parallel) {
        final List<Vec> origVecs = d.getDataVectors();
        final List<Double> origDistCache = this.dm.getAccelerationCache(origVecs, parallel);
        final int N = origVecs.size();
        final DenseMatrix delta = new DenseMatrix(N, N);
        // Task i writes only (i, j) and (j, i) for j > i, so every cell is written by
        // exactly one task. The diagonal keeps its initial value.
        ParallelUtils.run(parallel, N, i -> {
            for (int j = i + 1; j < N; ++j) {
                double dist = this.dm.dist(i, j, origVecs, origDistCache);
                delta.set(i, j, dist);
                delta.set(j, i, dist);
            }
        });
        SimpleDataSet embedded = this.transform(delta, parallel);
        DataSet<Type> transformed = d.shallowClone();
        transformed.replaceNumericFeatures(embedded.getDataVectors());
        @SuppressWarnings("unchecked") // shallowClone of a DataSet<Type> is a Type
        Type result = (Type) transformed;
        return result;
    }

    /**
     * Single-threaded convenience overload of {@link #transform(Matrix, boolean)}.
     *
     * @param delta square dissimilarity matrix
     * @return the embedded points, one per row of {@code delta}
     */
    public SimpleDataSet transform(Matrix delta) {
        return this.transform(delta, false);
    }

    /**
     * Embeds points given only their pairwise dissimilarities.
     *
     * @param delta    square N x N dissimilarity matrix. Only entries above the diagonal
     *                 are read, so it is expected to be symmetric.
     * @param parallel whether to use multiple threads
     * @return a data set with one numeric point of dimension {@link #getTargetDimension()}
     *         per row of {@code delta}
     * @throws IllegalArgumentException if {@code delta} is not square
     */
    public SimpleDataSet transform(Matrix delta, boolean parallel) {
        final int N = delta.rows();
        if (delta.cols() != N) {
            throw new IllegalArgumentException("dissimilarity matrix must be square, not " + N + "x" + delta.cols());
        }
        // Snapshot the metric so a concurrent setter call cannot change it mid-run.
        final DistanceMetric metric = this.embedMetric;
        final Random rand = RandomUtil.getRandom();

        // Random initial configuration. X_views are row views of X; the loop below
        // relies on them reflecting the values copied into X on each iteration.
        DenseMatrix X = new DenseMatrix(N, this.targetSize);
        List<Vec> X_views = new ArrayList<Vec>(N);
        for (int i = 0; i < N; ++i) {
            for (int j = 0; j < this.targetSize; ++j) {
                X.set(i, j, rand.nextDouble());
            }
            X_views.add(X.getRowView(i));
        }

        // The cache may be null for metrics without acceleration support. dist(...)
        // is given it as-is, and it is reassigned (never mutated) after each update.
        List<Double> X_rowCache = metric.getAccelerationCache(X_views, parallel);
        double oldStress = stress(metric, X_views, X_rowCache, delta, parallel);
        double stressChange = Double.POSITIVE_INFINITY;
        DenseMatrix B = new DenseMatrix(N, N);
        DenseMatrix X_new = new DenseMatrix(N, this.targetSize);
        for (int iter = 0; iter < this.maxIterations && stressChange > this.tolerance; ++iter) {
            fillGuttmanB(B, delta, metric, X_views, X_rowCache, parallel);
            // Guttman transform with unit weights: X <- (1/N) * B(X) * X.
            // X_new is cleared first because multiply(B, C) accumulates into C (C += A*B).
            X_new.zeroOut();
            if (parallel) {
                B.multiply(X, X_new, ParallelUtils.CACHED_THREAD_POOL);
            } else {
                B.multiply(X, X_new);
            }
            X_new.mutableMultiply(1.0 / (double) N);
            X_new.copyTo(X);
            X_rowCache = metric.getAccelerationCache(X_views, parallel);
            double newStress = stress(metric, X_views, X_rowCache, delta, parallel);
            stressChange = Math.abs(oldStress - newStress);
            oldStress = newStress;
        }

        SimpleDataSet sds = new SimpleDataSet(new CategoricalData[0], this.targetSize);
        for (Vec v : X_views) {
            sds.add(new DataPoint(v));
        }
        return sds;
    }

    /**
     * Fills {@code B} with the Guttman-transform matrix B(X) for the current embedding.
     *
     * <p>Off-diagonal entries are {@code B_ij = -delta_ij / d_ij(X)}, or 0 when the two
     * embedded points nearly coincide. Diagonal entries are {@code B_ii = -sum_{k != i} B_ik}.
     */
    private static void fillGuttmanB(final DenseMatrix B, final Matrix delta, final DistanceMetric metric,
            final List<Vec> X_views, final List<Double> X_rowCache, boolean parallel) {
        final int N = B.rows();
        ParallelUtils.run(parallel, N, i -> {
            for (int j = i + 1; j < N; ++j) {
                double d_ij = metric.dist(i, j, X_views, X_rowCache);
                // Avoid dividing by a (near-)zero embedded distance.
                double b_ij = d_ij > 1.0E-5 ? -delta.get(i, j) / d_ij : 0.0;
                B.set(i, j, b_ij);
                B.set(j, i, b_ij);
            }
        });
        for (int i = 0; i < N; ++i) {
            double offDiagonalSum = 0.0;
            for (int k = 0; k < N; ++k) {
                if (k != i) {
                    offDiagonalSum += B.get(i, k);
                }
            }
            B.set(i, i, -offDiagonalSum);
        }
    }

    /**
     * Computes the raw stress {@code sum_{i<j} (d_ij(X) - delta_ij)^2} of an embedding.
     *
     * @param metric     metric for embedded distances
     * @param X_views    the embedded points
     * @param X_rowCache acceleration cache for {@code X_views} (may be {@code null})
     * @param delta      the target dissimilarities
     * @param parallel   whether to use multiple threads
     * @return the raw stress
     */
    private static double stress(DistanceMetric metric, List<Vec> X_views, List<Double> X_rowCache, Matrix delta, boolean parallel) {
        final int N = delta.rows();
        return ParallelUtils.run(parallel, N, i -> {
            double localStress = 0.0;
            for (int j = i + 1; j < N; ++j) {
                double diff = metric.dist(i, j, X_views, X_rowCache) - delta.get(i, j);
                localStress += diff * diff;
            }
            return localStress;
        }, (a, b) -> a + b);
    }

    /**
     * @return the dimensionality of the produced embedding
     */
    @Override
    public int getTargetDimension() {
        return this.targetSize;
    }

    /**
     * Sets the dimensionality of the produced embedding.
     *
     * @param target the new dimension; must be at least 1
     * @return {@code true} if the value was accepted, {@code false} if {@code target < 1}
     */
    @Override
    public boolean setTargetDimension(int target) {
        if (target < 1) {
            return false;
        }
        this.targetSize = target;
        return true;
    }
}

