package jsat.clustering.kmeans;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Kernel k-means accelerated with Elkan-style triangle-inequality bounds.
 * <p>
 * For every point the algorithm keeps an upper bound on the distance to its
 * current best centroid and a lower bound on the distance to every centroid.
 * Combined with the pairwise centroid distances, these bounds let most
 * point-to-centroid kernel distances be skipped.
 */
public class ElkanKernelKMeans extends KernelKMeans {
    private static final long serialVersionUID = 4998832201379993827L;

    /**
     * Symmetric {@code k x k} matrix of distances between centroids, rebuilt by
     * {@link #calculateCentroidDistances}. {@link #findClosestCluster} prunes
     * candidates with it, so after training it must describe the final means.
     */
    private double[][] centroidSelfDistances;

    /**
     * {@code k x k} matrix of un-normalized weighted kernel sums between
     * clusters. Off-diagonal entry {@code [a][b]} is the sum of
     * {@code w_i * w_j * K(x_i, x_j)} over points {@code i} in cluster {@code a}
     * and {@code j} in cluster {@code b}. Maintained incrementally by
     * {@link #update_centroid_pair_dots}; only off-diagonal entries are read.
     */
    private double[][] centroidPairDots;

    /**
     * Creates a new clusterer.
     *
     * @param kernelTrick the kernel, passed on to {@link KernelKMeans}
     */
    public ElkanKernelKMeans(KernelTrick kernelTrick) {
        super(kernelTrick);
    }

    /**
     * Copy constructor. The centroid-distance and pair-dot matrices are deep
     * copied so that a copy of a trained model can still run
     * {@link #findClosestCluster}, which reads {@link #centroidSelfDistances}.
     *
     * @param toCopy the object to copy
     */
    public ElkanKernelKMeans(ElkanKernelKMeans toCopy) {
        super(toCopy);
        this.centroidSelfDistances = deepCopy(toCopy.centroidSelfDistances);
        this.centroidPairDots = deepCopy(toCopy.centroidPairDots);
    }

    /** Returns a row-by-row copy of {@code src}, or {@code null} if {@code src} is {@code null}. */
    private static double[][] deepCopy(double[][] src) {
        if (src == null) {
            return null;
        }
        double[][] copy = new double[src.length][];
        for (int row = 0; row < src.length; row++) {
            copy[row] = src[row].clone();
        }
        return copy;
    }

    /**
     * Returns the index of the centroid closest to {@code vec}.
     * <p>
     * Centroids whose owned weight ({@code ownes}) is not above {@code 1e-15}
     * are ignored. Once the distance {@code d} to centroid {@code c} is known,
     * every later centroid at least {@code 2d} away from {@code c} is pruned,
     * because the triangle inequality shows it cannot be closer than {@code c}.
     *
     * @param vec  the point to assign
     * @param list values passed through unchanged to {@code distance(vec, list, c)}
     * @return the index of the closest centroid, or -1 if no centroid was evaluated
     */
    @Override // jsat.clustering.kmeans.KernelKMeans
    public int findClosestCluster(Vec vec, List<Double> list) {
        final int k = this.meanSqrdNorms.length;
        double minDist = Double.MAX_VALUE;
        int closest = -1;
        boolean[] pruned = new boolean[k]; // all false on allocation
        for (int c = 0; c < k; c++) {
            if (this.ownes[c] > 1.0E-15d && !pruned[c]) {
                double dist = distance(vec, list, c);
                if (dist < minDist) {
                    minDist = dist;
                    closest = c;
                }
                // d(x, c_j) >= d(c, c_j) - d(x, c) >= d(x, c) whenever d(c, c_j) >= 2 d(x, c)
                for (int j = c + 1; j < k; j++) {
                    if (this.centroidSelfDistances[c][j] >= 2.0d * dist) {
                        pruned[j] = true;
                    }
                }
            }
        }
        return closest;
    }

    /**
     * Updates {@link #centroidPairDots} for points whose cluster changed between
     * {@code prevAssignments} and {@code newAssignments}.
     * <p>
     * For every pair {@code (i, j >= i)} where at least one point changed
     * cluster, {@code w_i * w_j * K(i, j)} is subtracted from the old cluster
     * pair (only when both old clusters are non-negative) and added to the new
     * cluster pair, symmetrically. Each chunk accumulates into a private buffer
     * and merges it row by row while holding that row's lock.
     *
     * @param prevAssignments the assignment the current pair dots reflect;
     *                        negative entries mean the point has not contributed yet
     * @param newAssignments  the assignment to update the pair dots to
     * @param parallel        whether the work is split with {@link ParallelUtils}
     */
    private void update_centroid_pair_dots(int[] prevAssignments, int[] newAssignments, boolean parallel) {
        final int n = this.X.size();
        ParallelUtils.run(parallel, n, (start, end) -> {
            final int k = this.centroidPairDots.length;
            double[][] localChanges = new double[k][k];
            for (int i = start; i < end; i++) {
                final double wI = this.W.get(i);
                final int oldCI = prevAssignments[i];
                final int newCI = newAssignments[i];
                for (int j = i; j < n; j++) {
                    final int oldCJ = prevAssignments[j];
                    final int newCJ = newAssignments[j];
                    if (oldCI == newCI && oldCJ == newCJ) {
                        continue; // neither point moved: this pair's contribution is unchanged
                    }
                    final double kij = wI * this.W.get(j) * this.kernel.eval(i, j, this.X, this.accel);
                    if (oldCI >= 0 && oldCJ >= 0) {
                        localChanges[oldCI][oldCJ] -= kij;
                        localChanges[oldCJ][oldCI] -= kij;
                    }
                    localChanges[newCI][newCJ] += kij;
                    localChanges[newCJ][newCI] += kij;
                }
            }
            for (int a = 0; a < k; a++) {
                final double[] row = this.centroidPairDots[a];
                synchronized (row) {
                    for (int b = 0; b < localChanges[a].length; b++) {
                        row[b] += localChanges[a][b];
                    }
                }
            }
        });
    }

    /**
     * Runs the bounded kernel k-means loop on {@code dataSet} and writes the
     * final cluster of each point into {@code assignment}.
     *
     * @param dataSet    the points to cluster
     * @param k          the number of clusters
     * @param assignment array passed to {@code setup(...)} and overwritten with the final designations
     * @param exactTotal currently has no effect: the returned total is always
     *                   computed from the per-point upper bounds
     * @param parallel   whether the work is split with {@link ParallelUtils}
     * @return the sum over points of the squared upper bound on the distance to
     *         the assigned centroid
     * @throws FailedToFitException wrapping any exception raised while clustering,
     *         including the {@link ClusterFailureException} thrown when there are
     *         fewer points than clusters
     */
    protected double cluster(DataSet dataSet, int k, int[] assignment, boolean exactTotal, boolean parallel) {
        try {
            final int n = dataSet.getSampleSize();
            if (n < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            this.X = dataSet.getDataVectors();
            setup(k, assignment, dataSet.getDataWeights());

            final double[][] lowerBound = new double[n][k];
            final double[] upperBound = new double[n];
            this.centroidSelfDistances = new double[k][k];
            this.centroidPairDots = new double[k][k];
            // half the distance from each centroid to its nearest other centroid
            final double[] halfNearest = new double[k];
            calculateCentroidDistances(k, this.centroidSelfDistances, halfNearest, assignment, null, parallel);

            // assignment that centroidPairDots reflects, used for incremental updates
            final int[] prevAssignment = new int[n];
            // r[q]: upperBound[q] may be loose and should be recomputed before use
            final boolean[] r = new boolean[n];
            final AtomicBoolean changeOccurred = new AtomicBoolean(true);

            initialClusterSetUp(k, n, lowerBound, upperBound, this.centroidSelfDistances, assignment, parallel);

            // Force at least two rounds; the initial designations are only applied in the first round's step 4.
            int forcedRounds = 2;
            int iterLimit = this.maximumIterations;
            boolean firstIteration = true;
            boolean stoppedByLimit = false;
            while (changeOccurred.get() || forcedRounds > 0) {
                if (iterLimit-- < 0) {
                    stoppedByLimit = true;
                    break;
                }
                forcedRounds--;
                changeOccurred.set(false);

                // Step 1: the distances computed before the loop are still current on the first round
                if (!firstIteration) {
                    calculateCentroidDistances(k, this.centroidSelfDistances, halfNearest, assignment, prevAssignment, parallel);
                }
                firstIteration = false;
                System.arraycopy(assignment, 0, prevAssignment, 0, assignment.length);

                // Steps 2 and 3. upperBound[q] bounds the distance to newDesignations[q], the point's
                // current best centroid (step 3b may change it mid-round), so all tests are relative to it.
                ParallelUtils.run(parallel, n, q -> {
                    if (upperBound[q] <= halfNearest[this.newDesignations[q]]) {
                        return;
                    }
                    for (int c = 0; c < k; c++) {
                        final int best = this.newDesignations[q];
                        if (c != best && upperBound[q] > lowerBound[q][c]
                                && upperBound[q] > this.centroidSelfDistances[best][c] * 0.5d) {
                            step3aBoundsUpdate(r, q, assignment, upperBound, lowerBound);
                            step3bUpdate(upperBound, q, lowerBound, c, this.centroidSelfDistances, assignment, changeOccurred);
                        }
                    }
                });

                step4_5_6_distanceMovedBoundsUpdate(k, n, lowerBound, upperBound, assignment, r, parallel);
            }

            if (stoppedByLimit && !firstIteration) {
                // The last step 4 may have moved the means after centroidSelfDistances was built;
                // refresh it so findClosestCluster prunes against the final means.
                calculateCentroidDistances(k, this.centroidSelfDistances, halfNearest, assignment, prevAssignment, parallel);
            }

            double total = 0.0d;
            for (int q = 0; q < n; q++) {
                total += upperBound[q] * upperBound[q];
            }
            return total;
        } catch (Exception e) {
            throw new FailedToFitException(e);
        }
    }

    /**
     * Computes each point's initial best centroid and bounds.
     * <p>
     * For each point, distances are computed in centroid order. Every computed
     * distance is stored as that centroid's lower bound. When a centroid
     * improves on the best so far, it becomes the upper bound, and later
     * centroids at least twice that distance away from it are skipped. The best
     * centroid is written to {@code newDesignations}.
     */
    private void initialClusterSetUp(int k, int n, double[][] lowerBound, double[] upperBound,
            double[][] centroidDists, int[] assignment, boolean parallel) {
        ParallelUtils.run(parallel, n, (start, end) -> {
            boolean[] skip = new boolean[k];
            for (int q = start; q < end; q++) {
                double minDist = Double.MAX_VALUE;
                int best = -1;
                Arrays.fill(skip, false);
                for (int c = 0; c < k; c++) {
                    if (skip[c]) {
                        continue;
                    }
                    double d = distance(q, c, assignment);
                    lowerBound[q][c] = d;
                    if (d < minDist) {
                        upperBound[q] = d;
                        minDist = d;
                        best = c;
                        for (int j = c + 1; j < k; j++) {
                            if (centroidDists[c][j] >= 2.0d * d) {
                                skip[j] = true;
                            }
                        }
                    }
                }
                this.newDesignations[q] = best;
            }
        });
    }

    /**
     * Applies the designations chosen in step 3 and updates all bounds.
     * <ul>
     * <li>Step 4: folds every point's designation change into the means and
     * measures how far each centroid moved.</li>
     * <li>Step 5: lowers every lower bound by its centroid's movement, clamped
     * at 0.</li>
     * <li>Step 6: copies {@code newDesignations} into {@code assignment}, raises
     * each upper bound by its centroid's movement and marks it with
     * {@code r[q] = true}.</li>
     * </ul>
     *
     * @return the sum of the values returned by {@code updateMeansFromChange}
     */
    private int step4_5_6_distanceMovedBoundsUpdate(int k, int n, double[][] lowerBound, double[] upperBound,
            int[] assignment, boolean[] r, boolean parallel) {
        final double[] distancesMoved = new double[k];
        // Snapshot the normalized squared norms before the means change
        final double[] oldMeanSqrdNorms = new double[this.meanSqrdNorms.length];
        for (int c = 0; c < oldMeanSqrdNorms.length; c++) {
            oldMeanSqrdNorms[c] = this.meanSqrdNorms[c] * this.normConsts[c];
        }

        // Step 4: accumulate per-chunk updates, then apply them under a shared lock
        final int[] changeCount = new int[1];
        ParallelUtils.run(parallel, n, (start, end) -> {
            // per-cluster accumulators filled by updateMeansFromChange, consumed by applyMeanUpdates
            double[] partialA = new double[k];
            double[] partialB = new double[k];
            int localChanges = 0;
            for (int q = start; q < end; q++) {
                localChanges += updateMeansFromChange(q, assignment, partialA, partialB);
            }
            synchronized (assignment) {
                applyMeanUpdates(partialA, partialB);
                changeCount[0] += localChanges;
            }
        });
        updateNormConsts();

        ParallelUtils.run(parallel, k, c -> {
            distancesMoved[c] = meanToMeanDistance(c, c, this.newDesignations, assignment, oldMeanSqrdNorms[c], parallel);
        });

        // Step 5
        ParallelUtils.run(parallel, k, c -> {
            for (int q = 0; q < n; q++) {
                lowerBound[q][c] = Math.max(lowerBound[q][c] - distancesMoved[c], 0.0d);
            }
        });

        System.arraycopy(this.newDesignations, 0, assignment, 0, n);

        // Step 6
        ParallelUtils.run(parallel, n, (start, end) -> {
            for (int q = start; q < end; q++) {
                upperBound[q] += distancesMoved[assignment[q]];
                r[q] = true;
            }
        });
        return changeCount[0];
    }

    /**
     * Step 3(a). The first time point {@code q} is examined after its upper
     * bound was loosened ({@code r[q]} set), this replaces the bound with the
     * exact distance to {@code assignment[q]}, which is also stored as that
     * centroid's lower bound.
     */
    private void step3aBoundsUpdate(boolean[] r, int q, int[] assignment, double[] upperBound, double[][] lowerBound) {
        if (r[q]) {
            r[q] = false;
            int current = assignment[q];
            double d = distance(q, current, assignment);
            lowerBound[q][current] = d;
            upperBound[q] = d;
        }
    }

    /**
     * Step 3(b). If the bounds still allow centroid {@code c} to be closer than
     * the point's current best centroid ({@code newDesignations[q]}), this
     * computes the exact distance and records it as the lower bound. If that
     * distance beats the upper bound, {@code q} is tentatively moved to
     * {@code c}, the upper bound is tightened and {@code changeOccurred} is set.
     */
    private void step3bUpdate(double[] upperBound, int q, double[][] lowerBound, int c, double[][] centroidDists,
            int[] assignment, AtomicBoolean changeOccurred) {
        final int best = this.newDesignations[q];
        if (upperBound[q] > lowerBound[q][c] || upperBound[q] > centroidDists[best][c] / 2.0d) {
            double d = distance(q, c, assignment);
            lowerBound[q][c] = d;
            if (d < upperBound[q]) {
                this.newDesignations[q] = c;
                upperBound[q] = d;
                changeOccurred.lazySet(true);
            }
        }
    }

    /**
     * Rebuilds {@code centroidDists} and {@code halfNearest} for the current
     * assignment.
     * <p>
     * First, {@link #centroidPairDots} is brought up to date incrementally from
     * {@code prevAssignments}. When that argument is {@code null}, every point
     * is treated as previously unassigned. The pairwise centroid distance is then
     * {@code sqrt(max(0, |m_a|^2 + |m_b|^2 - 2 * dot(a, b) / (W_a * W_b)))},
     * where {@code W_c} is the total weight assigned to cluster {@code c}.
     *
     * @param halfNearest receives half the distance from each centroid to its nearest other centroid
     */
    private void calculateCentroidDistances(int k, double[][] centroidDists, double[] halfNearest,
            int[] curAssignments, int[] prevAssignments, boolean parallel) {
        final int[] previous;
        if (prevAssignments == null) {
            previous = new int[curAssignments.length];
            Arrays.fill(previous, -1);
        } else {
            previous = prevAssignments;
        }
        update_centroid_pair_dots(previous, curAssignments, parallel);

        double[] clusterWeight = new double[k];
        for (int i = 0; i < curAssignments.length; i++) {
            clusterWeight[curAssignments[i]] += this.W.get(i);
        }

        for (int a = 0; a < k; a++) {
            final double normA = this.meanSqrdNorms[a] * this.normConsts[a];
            for (int b = a + 1; b < k; b++) {
                double dot = this.centroidPairDots[a][b] / (clusterWeight[a] * clusterWeight[b]);
                double sqrdDist = normA + this.meanSqrdNorms[b] * this.normConsts[b] - 2.0d * dot;
                // clamp: rounding can make the squared distance slightly negative
                double dist = Math.sqrt(Math.max(0.0d, sqrdDist));
                centroidDists[a][b] = dist;
                centroidDists[b][a] = dist;
            }
        }

        for (int a = 0; a < k; a++) {
            double nearest = Double.MAX_VALUE;
            for (int b = 0; b < k; b++) {
                if (a != b) {
                    nearest = Math.min(nearest, centroidDists[a][b]);
                }
            }
            halfNearest[a] = nearest / 2.0d;
        }
    }

    /**
     * Clusters {@code dataSet} into {@code i} clusters.
     *
     * @param dataSet      the points to cluster
     * @param i            the number of clusters
     * @param parallel     whether the work is split with {@link ParallelUtils}
     * @param designations output array for the cluster of each point; allocated if {@code null}
     * @return {@code designations}, filled with the final cluster of each point
     * @throws ClusterFailureException if there are fewer points than clusters
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int i, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < i) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, i, designations, false, parallel);
        return designations;
    }

    /** Returns a copy of this clusterer made with the copy constructor. */
    @Override // jsat.clustering.kmeans.KernelKMeans, jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public ElkanKernelKMeans mo114clone() {
        return new ElkanKernelKMeans(this);
    }
}
