/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.DoubleAdder;
import jsat.DataSet;
import jsat.clustering.ClustererBase;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.utils.IntSet;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Fuzzy, density-based clusterer that propagates cluster memberships over a
 * k-nearest-neighbor graph (presumably the FLAME algorithm of Fu &amp; Medico,
 * 2007; confirm against the paper).
 * <p>
 * {@link #cluster(DataSet, boolean, int[])} proceeds in four stages:
 * <ol>
 * <li>Neighborhood statistics. For every point, it computes the sum of
 * distances to its {@code k} nearest neighbors, where a smaller sum means a
 * denser neighborhood. It also computes normalized inverse-squared-distance
 * weights for those neighbors.</li>
 * <li>Seeds and outliers. A point whose distance sum is no larger than any
 * neighbor's becomes a cluster seed. A point whose sum is larger than every
 * neighbor's AND exceeds {@code mean + stndDevs * stddev} becomes an outlier.
 * Seeds with an outlier among their neighbors are dropped.</li>
 * <li>Propagation. Every other point starts with uniform membership over all
 * clusters plus one extra outlier bucket. It then repeatedly replaces its
 * membership vector with the weighted sum of its neighbors' vectors,
 * renormalized. This stops when the total change stabilizes (see
 * {@link #setEps(double)}) or after {@link #getMaxIterations()} rounds. Seed
 * and outlier memberships stay fixed throughout.</li>
 * <li>Labelling. Each point takes its highest-membership cluster; the outlier
 * bucket maps to {@code -1}. Clusters left with fewer than two members are
 * dissolved, and their points are reassigned to the best surviving cluster or
 * to the outlier bucket.</li>
 * </ol>
 */
public class FLAME
extends ClustererBase
implements Parameterized {
    private static final long serialVersionUID = 2393091020100706517L;
    /** Metric used for every nearest-neighbor search. */
    private DistanceMetric dm;
    /** Number of nearest neighbors, excluding the point itself, considered per point. */
    private int k;
    /** Upper bound on the number of membership-propagation iterations. */
    private int maxIterations;
    /** Collection answering the k-NN queries; rebuilt from the data on every call to cluster. */
    private VectorCollection<VecPaired<Vec, Integer>> vc = new DefaultVectorCollection<VecPaired<Vec, Integer>>();
    /** Outlier threshold, in standard deviations above the mean neighbor-distance sum. */
    private double stndDevs = 2.5;
    /** Propagation stops once the total membership change differs from the previous round's by less than this. */
    private double eps = 1.0E-6;

    /**
     * Creates a new FLAME clusterer.
     *
     * @param dm the distance metric to use
     * @param k the number of nearest neighbors per point; must be at least 1
     * @param maxIterations the maximum number of propagation iterations; must be at least 1
     * @throws IllegalArgumentException if {@code k} or {@code maxIterations} is less than 1
     */
    public FLAME(DistanceMetric dm, int k, int maxIterations) {
        this.setDistanceMetric(dm);
        this.setK(k);
        this.setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The distance metric and the vector collection are cloned.
     *
     * @param toCopy the object to copy
     */
    public FLAME(FLAME toCopy) {
        this.dm = toCopy.dm.clone();
        this.maxIterations = toCopy.maxIterations;
        this.vc = toCopy.vc.clone();
        this.k = toCopy.k;
        this.stndDevs = toCopy.stndDevs;
        this.eps = toCopy.eps;
    }

    /**
     * Sets the maximum number of membership-propagation iterations.
     *
     * @param maxIterations the iteration cap; must be at least 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Must perform a positive number of iterations, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * @return the maximum number of membership-propagation iterations
     */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the number of nearest neighbors (excluding the point itself) used per point.
     *
     * @param k the neighbor count; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public void setK(int k) {
        if (k < 1) {
            // With k < 1 no point has neighbors: every point would become its own seed.
            throw new IllegalArgumentException("Number of neighbors must be positive, not " + k);
        }
        this.k = k;
    }

    /**
     * @return the number of nearest neighbors used per point
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the convergence tolerance. Propagation stops once the total absolute
     * membership change of an iteration differs from the previous iteration's
     * by less than {@code eps}. A non-positive value disables early stopping.
     *
     * @param eps the tolerance; must not be NaN
     * @throws IllegalArgumentException if {@code eps} is NaN
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps)) {
            throw new IllegalArgumentException("Eps can not be NaN");
        }
        this.eps = eps;
    }

    /**
     * @return the convergence tolerance
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets how many standard deviations above the mean neighbor-distance sum a
     * locally sparsest point must lie to be treated as an outlier.
     *
     * @param stndDevs a finite, non-negative number of standard deviations
     * @throws IllegalArgumentException if {@code stndDevs} is negative, infinite or NaN
     */
    public void setStndDevs(double stndDevs) {
        if (stndDevs < 0.0 || Double.isInfinite(stndDevs) || Double.isNaN(stndDevs)) {
            throw new IllegalArgumentException("Standard Deviations must be non negative");
        }
        this.stndDevs = stndDevs;
    }

    /**
     * @return the outlier threshold in standard deviations
     */
    public double getStndDevs() {
        return this.stndDevs;
    }

    /**
     * @param dm the distance metric used for all neighbor searches
     */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /**
     * @return the distance metric used for all neighbor searches
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the vector collection used to answer nearest-neighbor queries. It is
     * rebuilt from the data on every call to {@link #cluster(DataSet, boolean, int[])}.
     *
     * @param vc the vector collection to use
     */
    public void setVectorCollectionFactory(VectorCollection<VecPaired<Vec, Integer>> vc) {
        this.vc = vc;
    }

    /**
     * Clusters {@code dataSet}.
     *
     * @param dataSet the data to cluster
     * @param parallel whether work may be spread over multiple threads
     * @param designations array to store the result in; a new array is allocated
     *        when {@code null} or shorter than the number of data points
     * @return the cluster label of each point, where {@code -1} means outlier.
     *         Any entries past the number of data points are set to {@code -1}.
     * @throws FailedToFitException if {@code k} is not smaller than the number of data points
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        final int n = dataSet.getSampleSize();
        if (this.k >= n) {
            throw new FailedToFitException("Number of k-neighbors (" + this.k + ") must be smaller than the number of datapoints (" + n + ")");
        }
        if (designations == null || designations.length < n) {
            designations = new int[n];
        }
        Arrays.fill(designations, -1);

        // Pair each vector with its row index so neighbor results map back to rows.
        List<VecPaired<Vec, Integer>> vecs = new ArrayList<VecPaired<Vec, Integer>>(n);
        for (int i = 0; i < n; ++i) {
            vecs.add(new VecPaired<Vec, Integer>(dataSet.getDataPoint(i).getNumericalValues(), i));
        }
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, parallel);
        this.vc.build(parallel, vecs, this.dm);
        // Ask for k + 1 neighbors: entry 0 of each result is expected to be the query point itself.
        final List<List<VecPaired<VecPaired<Vec, Integer>, Double>>> allNNs =
                VectorCollectionUtils.allNearestNeighbors(this.vc, vecs, this.k + 1, parallel);

        // distSum[i] is the sum of distances from point i to its neighbors.
        // A SMALLER value therefore means a DENSER neighborhood.
        final double[] distSum = new double[n];
        final double[][] weights = new double[n][this.k];
        final OnLineStatistics distSumStats = new OnLineStatistics();
        for (int i = 0; i < n; ++i) {
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            int numNeighbors = Math.min(this.k, knns.size() - 1);
            double[] w = weights[i];
            double weightSum = 0.0;
            for (int j = 0; j < numNeighbors; ++j) {
                double dist = knns.get(j + 1).getPair();
                distSum[i] += dist;
                // Inverse squared distance. The cap keeps duplicates (distance 0) finite
                // and lets k capped weights be summed without overflowing.
                w[j] = Math.min(1.0 / Math.pow(dist, 2.0), Double.MAX_VALUE / (double) (this.k + 1));
                weightSum += w[j];
            }
            distSumStats.add(distSum[i]);
            for (int j = 0; j < numNeighbors; ++j) {
                w[j] /= weightSum;
            }
        }

        // Seeds map point index -> cluster label. Outliers is a set of point indices.
        final HashMap<Integer, Integer> seeds = new HashMap<Integer, Integer>();
        final IntSet outliers = new IntSet();
        final double threshold = distSumStats.getMean() + distSumStats.getStandardDeviation() * this.stndDevs;
        for (int i = 0; i < n; ++i) {
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            boolean lowest = true;  // distSum[i] <= every neighbor's: locally densest
            boolean highest = true; // distSum[i] >  every neighbor's: locally sparsest
            for (int j = 1; j < knns.size() && (lowest || highest); ++j) {
                int jNN = knns.get(j).getVector().getPair();
                if (distSum[i] > distSum[jNN]) {
                    lowest = false;
                } else {
                    highest = false;
                }
            }
            if (lowest) {
                seeds.put(i, seeds.size());
            } else if (highest && distSum[i] > threshold) {
                outliers.add(i);
            }
        }

        // Drop any seed with an outlier among its neighbors.
        // Then relabel the survivors so cluster labels stay contiguous.
        final int origSize = seeds.size();
        seeds.keySet().removeIf(i -> {
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns = allNNs.get(i);
            for (int j = 1; j < knns.size(); ++j) {
                if (outliers.contains(knns.get(j).getVector().getPair())) {
                    return true;
                }
            }
            return false;
        });
        if (seeds.size() != origSize) {
            IntSet survivors = new IntSet(seeds.keySet());
            seeds.clear();
            for (int s : survivors) {
                seeds.put(s, seeds.size());
            }
        }
        final int numClusters = seeds.size();
        final int outlierBucket = numClusters; // extra, last membership column

        // Initial memberships are one-hot for seeds and outliers, uniform for everyone else.
        double[][] fuzzy = new double[n][numClusters + 1];
        for (int i = 0; i < n; ++i) {
            Integer label = seeds.get(i);
            if (label != null) {
                fuzzy[i][label] = 1.0;
            } else if (outliers.contains(i)) {
                fuzzy[i][outlierBucket] = 1.0;
            } else {
                Arrays.fill(fuzzy[i], 1.0 / (double) (numClusters + 1));
            }
        }
        // The second buffer MUST start as a copy. Seed and outlier rows are never
        // rewritten during propagation, so both buffers need their fixed memberships.
        double[][] fuzzy2 = new double[n][];
        for (int i = 0; i < n; ++i) {
            fuzzy2[i] = fuzzy[i].clone();
        }

        double prevScore = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < this.maxIterations; ++iter) {
            final double[][] from = fuzzy;
            final double[][] to = fuzzy2;
            final DoubleAdder score = new DoubleAdder();
            ParallelUtils.run(parallel, n, i -> {
                if (outliers.contains(i) || seeds.containsKey(i)) {
                    return; // fixed memberships
                }
                double[] row = to[i];
                Arrays.fill(row, 0.0);
                accumulateNeighborMemberships(row, allNNs.get(i), weights[i], from);
                double rowSum = 0.0;
                for (double v : row) {
                    rowSum += v;
                }
                double localScore = 0.0;
                for (int z = 0; z < row.length; ++z) {
                    row[z] /= rowSum + 1.0E-6;
                    localScore += Math.abs(from[i][z] - row[z]);
                }
                score.add(localScore);
            });
            // Swap first, so the update just computed is kept even when we stop here.
            fuzzy = to;
            fuzzy2 = from;
            double curScore = score.doubleValue();
            if (Math.abs(prevScore - curScore) < this.eps) {
                break;
            }
            prevScore = curScore;
        }

        // Hard assignment: take the first strictly-positive maximum. If no entry is
        // positive, the point goes to the outlier bucket.
        int[] bucketSizes = new int[numClusters + 1];
        for (int i = 0; i < n; ++i) {
            int best = outlierBucket;
            double bestVal = 0.0;
            for (int z = 0; z < fuzzy[i].length; ++z) {
                if (fuzzy[i][z] > bestVal) {
                    bestVal = fuzzy[i][z];
                    best = z;
                }
            }
            bucketSizes[best]++;
            designations[i] = best == outlierBucket ? -1 : best;
        }

        // Clusters with fewer than two members are dissolved; survivors get compact labels.
        int[] newLabel = new int[numClusters];
        int kept = 0;
        for (int c = 0; c < numClusters; ++c) {
            newLabel[c] = bucketSizes[c] > 1 ? kept++ : -1;
        }
        if (kept != numClusters) {
            double[] votes = new double[numClusters + 1];
            for (int i = 0; i < n; ++i) {
                int d = designations[i];
                if (d < 0) {
                    continue; // outliers stay -1
                }
                if (newLabel[d] != -1) {
                    designations[i] = newLabel[d];
                    continue;
                }
                // Point sat in a dissolved cluster. Re-vote using its neighbors'
                // memberships, considering only surviving clusters and the outlier bucket.
                Arrays.fill(votes, 0.0);
                accumulateNeighborMemberships(votes, allNNs.get(i), weights[i], fuzzy);
                int best = -1;
                double bestVal = -1.0;
                for (int z = 0; z <= numClusters; ++z) {
                    boolean eligible = z == outlierBucket || newLabel[z] != -1;
                    if (eligible && votes[z] > bestVal) {
                        bestVal = votes[z];
                        best = z;
                    }
                }
                designations[i] = (best == -1 || best == outlierBucket) ? -1 : newLabel[best];
            }
        }
        return designations;
    }

    /**
     * Adds the weighted membership vectors of a point's neighbors into {@code out}.
     * Entry 0 of {@code knns} is skipped because it is expected to be the point itself.
     * Neighbor {@code j >= 1} contributes {@code w[j - 1] * memberships[neighbor]}.
     *
     * @param out accumulator; must have the same length as each membership row
     * @param knns the point's neighbor list, self first
     * @param w the point's normalized neighbor weights
     * @param memberships membership rows indexed by point
     */
    private static void accumulateNeighborMemberships(double[] out,
            List<VecPaired<VecPaired<Vec, Integer>, Double>> knns, double[] w, double[][] memberships) {
        int last = Math.min(knns.size() - 1, w.length);
        for (int j = 1; j <= last; ++j) {
            int jNN = knns.get(j).getVector().getPair();
            double weight = w[j - 1];
            double[] m = memberships[jNN];
            for (int z = 0; z < m.length; ++z) {
                out[z] += weight * m[z];
            }
        }
    }

    /**
     * @return a deep copy of this clusterer (see {@link #FLAME(FLAME)})
     */
    @Override
    public FLAME clone() {
        return new FLAME(this);
    }
}

