package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.DoubleAdder;
import jsat.DataSet;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionUtils;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.utils.IntSet;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/clustering/FLAME.class */
/**
 * FLAME (Fuzzy clustering by Local Approximation of MEmberships) clusterer.
 * <p>
 * The clustering proceeds in stages:
 * <ol>
 * <li>For every point the {@code k} nearest neighbours are found and the sum of the distances to
 * them is computed. A smaller sum means the point lies in a denser region.</li>
 * <li>Points whose distance sum is no larger than that of any of their neighbours become Cluster
 * Supporting Objects (CSOs). Points whose sum is larger than all of their neighbours' sums and also
 * exceeds {@code mean + stndDevs * standardDeviation} of all sums become outliers. CSOs that have an
 * outlier among their neighbours are demoted to ordinary points.</li>
 * <li>Every CSO gets full membership in its own cluster, every outlier full membership in an extra
 * "outlier" group, and every other point an even split over all groups. The memberships of the
 * non-fixed points are then repeatedly replaced by the weighted average of their neighbours'
 * memberships, using normalised inverse-squared-distance weights.</li>
 * <li>Each point is assigned to the group with the largest membership. Clusters with at most one
 * member are dissolved and their points re-assigned from their neighbours' memberships.</li>
 * </ol>
 * Points that end up in the outlier group receive the designation {@code -1}.
 */
public class FLAME extends ClustererBase implements Parameterized {
    private static final long serialVersionUID = 2393091020100706517L;
    /** Metric used for the nearest-neighbour search. */
    private DistanceMetric dm;
    /** Number of nearest neighbours (excluding the point itself) considered for each point. */
    private int k;
    /** Upper bound on the number of membership-propagation iterations. */
    private int maxIterations;
    /** Collection that is built over the data and queried for nearest neighbours. */
    private VectorCollection<VecPaired<Vec, Integer>> vc = new DefaultVectorCollection<>();
    /** Standard deviations above the mean distance sum beyond which a local maximum is an outlier. */
    private double stndDevs = 2.5d;
    /** Tolerance on the change in total membership movement used to stop propagation early. */
    private double eps = 1.0E-6d;

    /**
     * Creates a new FLAME clusterer.
     *
     * @param distanceMetric the metric used for the nearest-neighbour search
     * @param k the number of nearest neighbours considered for each point
     * @param maxIterations the maximum number of membership-propagation iterations
     * @throws IllegalArgumentException if {@code k} or {@code maxIterations} is less than 1
     */
    public FLAME(DistanceMetric distanceMetric, int k, int maxIterations) {
        setDistanceMetric(distanceMetric);
        setK(k);
        setMaxIterations(maxIterations);
    }

    /**
     * Copy constructor. The distance metric and the vector collection are cloned.
     *
     * @param flame the clusterer to copy
     */
    public FLAME(FLAME flame) {
        this.dm = flame.dm.mo185clone();
        this.maxIterations = flame.maxIterations;
        this.vc = flame.vc.m209clone();
        this.k = flame.k;
        this.stndDevs = flame.stndDevs;
        this.eps = flame.eps;
    }

    /**
     * Sets the maximum number of membership-propagation iterations.
     *
     * @param maxIterations the iteration limit, at least 1
     * @throws IllegalArgumentException if {@code maxIterations < 1}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new IllegalArgumentException("Must perform a positive number of iterations, not " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of membership-propagation iterations */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the number of nearest neighbours considered for each point.
     *
     * @param k the neighbourhood size, at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public void setK(int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Number of neighbors must be positive, not " + k);
        }
        this.k = k;
    }

    /** @return the number of nearest neighbours considered for each point */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the convergence tolerance. Propagation stops once the total membership change of one
     * iteration differs from that of the previous iteration by less than this value.
     *
     * @param eps the tolerance; must not be NaN
     * @throws IllegalArgumentException if {@code eps} is NaN
     */
    public void setEps(double eps) {
        if (Double.isNaN(eps)) {
            throw new IllegalArgumentException("Eps can not be NaN");
        }
        this.eps = eps;
    }

    /** @return the convergence tolerance */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets how many standard deviations above the mean neighbour-distance sum a point must be (in
     * addition to being a local maximum) to be treated as an outlier.
     *
     * @param stndDevs a finite, non-negative number of standard deviations
     * @throws IllegalArgumentException if {@code stndDevs} is negative, infinite or NaN
     */
    public void setStndDevs(double stndDevs) {
        if (stndDevs < 0.0d || Double.isInfinite(stndDevs) || Double.isNaN(stndDevs)) {
            throw new IllegalArgumentException("Standard Deviations must be non negative");
        }
        this.stndDevs = stndDevs;
    }

    /** @return the outlier threshold in standard deviations */
    public double getStndDevs() {
        return this.stndDevs;
    }

    /** @param distanceMetric the metric used for the nearest-neighbour search */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    /** @return the metric used for the nearest-neighbour search */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the vector collection that is built over the data and queried for nearest neighbours.
     *
     * @param vectorCollection the collection to use
     */
    public void setVectorCollectionFactory(VectorCollection<VecPaired<Vec, Integer>> vectorCollection) {
        this.vc = vectorCollection;
    }

    /**
     * Clusters the given data set with FLAME.
     *
     * @param dataSet the data to cluster
     * @param parallel whether metric training, the neighbour search and membership propagation may
     *        run in parallel
     * @param designations optional output array; a new one is allocated when it is {@code null} or
     *        shorter than the number of data points
     * @return for each data point its cluster index, or {@code -1} for outliers; entries past the
     *         number of data points are set to {@code -1}
     * @throws FailedToFitException if {@code k} is not smaller than the number of data points
     */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        final int n = dataSet.getSampleSize();
        if (this.k >= n) {
            throw new FailedToFitException("Number of k-neighbors (" + this.k
                    + ") can not be larger than the number of datapoints (" + n + ")");
        }
        if (designations == null || designations.length < n) {
            designations = new int[n];
        }
        Arrays.fill(designations, -1);

        // Pair every vector with its index so neighbours can be mapped back to data points.
        List<VecPaired<Vec, Integer>> points = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            points.add(new VecPaired<>(dataSet.getDataPoint(i).getNumericalValues(), i));
        }
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, parallel);
        this.vc.build(parallel, points, this.dm);
        // k + 1 results are requested and position 0 of each list is skipped everywhere below;
        // presumably position 0 is the query point itself.
        final List<?> allNNs = VectorCollectionUtils.allNearestNeighbors(this.vc, points, this.k + 1, parallel);

        // Sum of neighbour distances (smaller = denser) and normalised neighbour weights per point.
        final double[] distSum = new double[n];
        final double[][] weights = new double[n][this.k];
        OnLineStatistics distStats = new OnLineStatistics();
        for (int i = 0; i < n; i++) {
            List<?> knn = (List<?>) allNNs.get(i);
            for (int j = 1; j < knn.size(); j++) {
                double dist = neighborDistance(knn, j);
                weights[i][j - 1] = dist;
                distSum[i] += dist;
            }
            distStats.add(distSum[i]);
            double total = 0.0d;
            for (int j = 0; j < this.k; j++) {
                // The cap keeps zero distances finite and the sum of k capped weights below MAX_VALUE.
                double w = Math.min(1.0d / (weights[i][j] * weights[i][j]), Double.MAX_VALUE / (this.k + 1));
                weights[i][j] = w;
                total += w;
            }
            for (int j = 0; j < this.k; j++) {
                weights[i][j] /= total;
            }
        }

        // Classify candidate CSOs (local density maxima) and outliers (sparse local minima).
        final double threshold = distStats.getMean() + distStats.getStandardDeviation() * this.stndDevs;
        final boolean[] isCandidateCso = new boolean[n];
        final boolean[] isOutlier = new boolean[n];
        for (int i = 0; i < n; i++) {
            List<?> knn = (List<?>) allNNs.get(i);
            boolean densest = true;  // distSum[i] <= every neighbour's sum
            boolean sparsest = true; // distSum[i] >  every neighbour's sum
            for (int j = 1; j < knn.size() && (densest || sparsest); j++) {
                if (distSum[i] > distSum[neighborIndex(knn, j)]) {
                    densest = false;
                } else {
                    sparsest = false;
                }
            }
            if (densest) {
                isCandidateCso[i] = true;
            } else if (sparsest && distSum[i] > threshold) {
                isOutlier[i] = true;
            }
        }

        // CSOs next to an outlier are demoted; the survivors are numbered in index order.
        final int[] csoId = new int[n];
        int numCso = 0;
        for (int i = 0; i < n; i++) {
            if (isCandidateCso[i] && !hasOutlierNeighbor((List<?>) allNNs.get(i), isOutlier)) {
                csoId[i] = numCso++;
            } else {
                csoId[i] = -1;
            }
        }

        // Initial memberships; the last column (index numCso) is the outlier group.
        final int groups = numCso + 1;
        final boolean[] fixed = new boolean[n];
        double[][] fuzzy = new double[n][groups];
        for (int i = 0; i < n; i++) {
            if (csoId[i] >= 0) {
                fuzzy[i][csoId[i]] = 1.0d;
                fixed[i] = true;
            } else if (isOutlier[i]) {
                fuzzy[i][numCso] = 1.0d;
                fixed[i] = true;
            } else {
                Arrays.fill(fuzzy[i], 1.0d / groups);
            }
        }
        // Propagation never rewrites fixed rows, so both buffers must start with them filled in;
        // otherwise they read back as all-zero after the first buffer swap.
        double[][] next = new double[n][];
        for (int i = 0; i < n; i++) {
            next[i] = fuzzy[i].clone();
        }

        double prevChange = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < this.maxIterations; iter++) {
            final double[][] current = fuzzy;
            final double[][] updated = next;
            final DoubleAdder change = new DoubleAdder();
            ParallelUtils.run(parallel, n, i -> {
                if (fixed[i]) {
                    return;
                }
                double[] row = updated[i];
                Arrays.fill(row, 0.0d);
                accumulateNeighborMemberships((List<?>) allNNs.get(i), weights[i], current, row);
                double total = 0.0d;
                for (double v : row) {
                    total += v;
                }
                double diff = 0.0d;
                for (int c = 0; c < row.length; c++) {
                    row[c] /= total + 1.0E-6d; // small constant guards against an all-zero row
                    diff += Math.abs(current[i][c] - row[c]);
                }
                change.add(diff);
            });
            // Swap before testing convergence so that fuzzy always holds the newest memberships.
            fuzzy = updated;
            next = current;
            double totalChange = change.sum();
            if (Math.abs(prevChange - totalChange) < this.eps) {
                break;
            }
            prevChange = totalChange;
        }

        // Hard assignment to the group of largest membership; groupSize counts members per group.
        final int[] groupSize = new int[groups];
        for (int i = 0; i < n; i++) {
            int best = argMax(fuzzy[i], 0.0d);
            if (best == -1) {
                best = numCso; // no positive membership at all: treat as outlier
            }
            groupSize[best]++;
            designations[i] = best == numCso ? -1 : best;
        }

        // Dissolve groups with at most one member and renumber the rest consecutively.
        final int[] remap = new int[groups];
        int kept = 0;
        for (int c = 0; c < groups; c++) {
            remap[c] = groupSize[c] > 1 ? kept++ : -1;
        }
        if (kept != groups) {
            final double[] votes = new double[groups];
            for (int i = 0; i < n; i++) {
                int c = designations[i];
                if (c < 0) {
                    continue; // outlier
                }
                if (remap[c] != -1) {
                    designations[i] = remap[c];
                    continue;
                }
                // The point's cluster was dissolved: re-assign it from its neighbours' memberships.
                Arrays.fill(votes, 0.0d);
                accumulateNeighborMemberships((List<?>) allNNs.get(i), weights[i], fuzzy, votes);
                int best = argMax(votes, -1.0d);
                designations[i] = (best < 0 || best == numCso) ? -1 : remap[best];
            }
        }
        return designations;
    }

    /** Returns the data-set index of the {@code j}-th entry of a nearest-neighbour result list. */
    private static int neighborIndex(List<?> knn, int j) {
        VecPaired<?, ?> hit = (VecPaired<?, ?>) knn.get(j);
        return (Integer) ((VecPaired<?, ?>) hit.getVector()).getPair();
    }

    /** Returns the distance stored with the {@code j}-th entry of a nearest-neighbour result list. */
    private static double neighborDistance(List<?> knn, int j) {
        return (Double) ((VecPaired<?, ?>) knn.get(j)).getPair();
    }

    /** Returns whether any neighbour (result positions 1 and up) is flagged as an outlier. */
    private static boolean hasOutlierNeighbor(List<?> knn, boolean[] isOutlier) {
        for (int j = 1; j < knn.size(); j++) {
            if (isOutlier[neighborIndex(knn, j)]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Adds {@code w[j-1] * memberships[neighbour j]} into {@code out} for every neighbour that has a
     * weight, i.e. result positions 1 through {@code min(knn.size() - 1, w.length)}.
     */
    private static void accumulateNeighborMemberships(List<?> knn, double[] w, double[][] memberships, double[] out) {
        int last = Math.min(knn.size() - 1, w.length);
        for (int j = 1; j <= last; j++) {
            double wj = w[j - 1];
            double[] neighbor = memberships[neighborIndex(knn, j)];
            for (int c = 0; c < out.length; c++) {
                out[c] += wj * neighbor[c];
            }
        }
    }

    /**
     * Returns the index of the first entry that is strictly greater than {@code floor} and all
     * earlier entries, or {@code -1} when no entry exceeds {@code floor}.
     */
    private static int argMax(double[] values, double floor) {
        int best = -1;
        double bestValue = floor;
        for (int c = 0; c < values.length; c++) {
            if (values[c] > bestValue) {
                bestValue = values[c];
                best = c;
            }
        }
        return best;
    }

    @Override // jsat.clustering.ClustererBase
    /* renamed from: clone */
    public FLAME mo114clone() {
        return new FLAME(this);
    }
}
