package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DenseSparseMetric;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/kmeans/ElkanKMeans.class */
/**
 * k-means clustering accelerated with triangle-inequality bounds. The step numbering used by the
 * helper methods follows Elkan, "Using the Triangle Inequality to Accelerate k-Means" (ICML 2003):
 * every point keeps an upper bound on the distance to its assigned centroid and a lower bound on its
 * distance to every centroid, and a point-to-centroid distance is only computed when those bounds
 * cannot rule the centroid out.
 *
 * <p>Because the pruning relies on the triangle inequality, the distance metric must be subadditive
 * ({@link DistanceMetric#isSubadditive()}); the constructors throw {@link ClusterFailureException}
 * otherwise.
 */
public class ElkanKMeans extends KMeans {
    private static final long serialVersionUID = -1629432283103273051L;

    /**
     * Dense/sparse specialised view of {@code dm}. It is re-chosen at the start of every
     * {@code cluster} run: {@code dm} itself when {@link #useDenseSparse} is set and {@code dm} is a
     * {@link DenseSparseMetric}, {@code null} otherwise.
     */
    private DenseSparseMetric dmds;

    /** Whether point-to-centroid distances should go through {@link DenseSparseMetric} when possible. */
    private boolean useDenseSparse = false;

    /**
     * Creates a new clusterer.
     *
     * @param distanceMetric the metric to cluster with; must satisfy the triangle inequality
     * @param random the random source handed to the base class (used when seeding centroids)
     * @param seedSelection the method used to choose the initial centroids
     * @throws ClusterFailureException if {@code distanceMetric} is not subadditive
     */
    public ElkanKMeans(DistanceMetric distanceMetric, Random random, SeedSelectionMethods.SeedSelection seedSelection) {
        super(distanceMetric, seedSelection, random);
        if (!distanceMetric.isSubadditive()) {
            throw new ClusterFailureException("KMeans implementation requires the triangle inequality");
        }
    }

    /**
     * Creates a new clusterer using the default seed selection method.
     *
     * @param distanceMetric the metric to cluster with; must satisfy the triangle inequality
     * @param random the random source used when seeding centroids
     */
    public ElkanKMeans(DistanceMetric distanceMetric, Random random) {
        this(distanceMetric, random, DEFAULT_SEED_SELECTION);
    }

    /**
     * Creates a new clusterer using the default seed selection and random source.
     *
     * @param distanceMetric the metric to cluster with; must satisfy the triangle inequality
     */
    public ElkanKMeans(DistanceMetric distanceMetric) {
        this(distanceMetric, RandomUtil.getRandom());
    }

    /** Creates a new clusterer using {@link EuclideanDistance}. */
    public ElkanKMeans() {
        this(new EuclideanDistance());
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the clusterer to copy
     */
    public ElkanKMeans(ElkanKMeans toCopy) {
        super(toCopy);
        if (toCopy.dmds != null) {
            this.dmds = (DenseSparseMetric) toCopy.dmds.mo185clone();
        }
        this.useDenseSparse = toCopy.useDenseSparse;
    }

    /**
     * Sets whether distances should be computed through the {@link DenseSparseMetric} interface. This
     * only takes effect when the distance metric actually is a {@link DenseSparseMetric}.
     *
     * @param useDenseSparse {@code true} to enable the dense/sparse distance path
     */
    public void setUseDenseSparse(boolean useDenseSparse) {
        this.useDenseSparse = useDenseSparse;
    }

    /**
     * @return whether the dense/sparse distance path is enabled
     */
    public boolean isUseDenseSparse() {
        return this.useDenseSparse;
    }

    /**
     * Runs bound-accelerated k-means iterations until no point changes cluster (at least two passes
     * are always made) or the iteration limit {@code MaxIterLimit} is exhausted.
     *
     * @param dataSet the data to cluster
     * @param accelCache acceleration cache for the data vectors, or {@code null} to compute one here
     * @param k the number of clusters
     * @param means the centroids. If the list does not hold exactly {@code k} vectors it is
     *     re-seeded. It is updated in place, and sparse centroids are replaced by dense copies.
     * @param assignment output: the cluster index of every data point
     * @param exactTotal if {@code true} the returned error recomputes each point's distance to its
     *     centroid; otherwise the maintained upper bounds are used, which may overestimate it
     * @param parallel whether to run the work in parallel
     * @param returnError whether to compute the sum of squared distances (and, when
     *     {@code saveCentroidDistance} is set, record each point's distance)
     * @param dataPointWeights per-point weights, or {@code null} to use the data set's weights
     * @return the sum of squared point-to-centroid distances when {@code returnError} is set,
     *     {@code 0} otherwise, or {@link Double#MAX_VALUE} if an unexpected exception was logged
     * @throws ClusterFailureException if the data set has fewer points than {@code k}
     */
    @Override
    public double cluster(DataSet dataSet, List<Double> accelCache, final int k, List<Vec> means, int[] assignment,
            boolean exactTotal, boolean parallel, boolean returnError, Vec dataPointWeights) {
        try {
            final int numPoints = dataSet.getSampleSize();
            final int numDims = dataSet.getNumNumericalVars();
            if (numPoints < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            final Vec weights = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
            TrainableDistanceMetric.trainIfNeeded(dm, dataSet);
            final List<Vec> points = dataSet.getDataVectors();
            final List<Double> dataCache = accelCache == null ? dm.getAccelerationCache(points, parallel) : accelCache;

            if (means.size() != k) {
                means.clear();
                means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, dm, dataCache, rand, seedSelection, parallel));
            }
            // Centroids receive dense in-place updates, so keep them dense.
            for (int c = 0; c < means.size(); c++) {
                if (means.get(c).isSparse()) {
                    means.set(c, new DenseVector(means.get(c)));
                }
            }

            // Pick the distance path for this run before any distance is computed with it, and reset it
            // when the dense/sparse option is off so a previous run's choice does not leak through.
            dmds = (useDenseSparse && dm instanceof DenseSparseMetric) ? (DenseSparseMetric) dm : null;
            final double[] meanSummaryConsts = dmds != null ? new double[k] : null;

            final double[][] lowerBound = new double[numPoints][k];
            final double[] upperBound = new double[numPoints];
            final double[][] centroidSelfDistances = new double[k][k];
            // sC[c] = half the distance from centroid c to its nearest other centroid.
            final double[] sC = new double[k];
            // Also fills meanSummaryConsts, so the first pass's dense/sparse distances are valid.
            calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanSummaryConsts, parallel);

            final AtomicDoubleArray meanCount = new AtomicDoubleArray(k);
            final Vec[] oldMeans = new Vec[k];
            final Vec[] meanSums = new Vec[k];
            final List<List<Double>> centroidCache = new ArrayList<>(k);
            for (int c = 0; c < k; c++) {
                oldMeans[c] = means.get(c).mo46clone();
                if (dm.supportsAcceleration()) {
                    centroidCache.add(dm.getQueryInfo(means.get(c)));
                } else {
                    centroidCache.add(Collections.<Double>emptyList());
                }
                meanSums[c] = new DenseVector(numDims);
            }

            // Per-thread buffers of pending changes to the per-cluster weighted sums.
            final ThreadLocal<Vec[]> localDeltas = ThreadLocal.withInitial(() -> {
                Vec[] deltas = new Vec[k];
                for (int c = 0; c < k; c++) {
                    deltas[c] = new DenseVector(numDims);
                }
                return deltas;
            });

            initialClusterSetUp(k, numPoints, points, means, lowerBound, upperBound, centroidSelfDistances,
                    assignment, meanCount, meanSums, dataCache, centroidCache, localDeltas, parallel, weights);

            // r[q]: upper bound of point q may be loose and must be recomputed before use (step 3a).
            final boolean[] r = new boolean[numPoints];
            final AtomicBoolean changeOccurred = new AtomicBoolean(true);
            int atLeast = 2;
            int iterLimit = MaxIterLimit;
            while ((changeOccurred.get() || atLeast > 0) && iterLimit-- >= 0) {
                atLeast--;
                changeOccurred.set(false);
                // Step 1: centroid-to-centroid distances (already computed above for the first pass).
                if (iterLimit < MaxIterLimit - 1) {
                    calculateCentroidDistances(k, centroidSelfDistances, means, sC, meanSummaryConsts, parallel);
                }
                // Steps 2-3 for every point; step 4 flushes the thread-local deltas once per chunk.
                ParallelUtils.run(parallel, numPoints, (start, end) -> {
                    for (int q = start; q < end; q++) {
                        // Step 2: no other centroid can be closer than the assigned one.
                        if (upperBound[q] <= sC[assignment[q]]) {
                            continue;
                        }
                        Vec v = points.get(q);
                        for (int c = 0; c < k; c++) {
                            if (c != assignment[q] && upperBound[q] > lowerBound[q][c]
                                    && upperBound[q] > centroidSelfDistances[assignment[q]][c] * 0.5) {
                                step3aBoundsUpdate(points, r, q, v, means, assignment, upperBound, lowerBound,
                                        meanSummaryConsts, dataCache, centroidCache);
                                step3bUpdate(points, upperBound, q, lowerBound, c, centroidSelfDistances, assignment, v,
                                        means, localDeltas, meanCount, changeOccurred, meanSummaryConsts, dataCache,
                                        centroidCache, weights);
                            }
                        }
                    }
                    step4UpdateCentroids(meanSums, localDeltas);
                });
                step5_6_distanceMovedBoundsUpdate(k, oldMeans, means, meanSums, meanCount, numPoints, lowerBound,
                        upperBound, assignment, r, centroidCache, parallel);
            }

            double totalError = 0.0;
            if (returnError) {
                nearestCentroidDist = saveCentroidDistance ? new double[numPoints] : null;
                for (int q = 0; q < numPoints; q++) {
                    double dist = exactTotal
                            ? dm.dist(q, means.get(assignment[q]), centroidCache.get(assignment[q]), points, dataCache)
                            : upperBound[q];
                    totalError += dist * dist;
                    if (saveCentroidDistance) {
                        nearestCentroidDist[q] = dist;
                    }
                }
            }
            return totalError;
        } catch (ClusterFailureException e) {
            // A deliberate, caller-facing failure (e.g. too few points): do not hide it.
            throw e;
        } catch (Exception e) {
            Logger.getLogger(ElkanKMeans.class.getName()).log(Level.SEVERE, null, e);
            return Double.MAX_VALUE;
        }
    }

    /**
     * Assigns every point to its nearest centroid and initialises the bounds and per-cluster sums.
     * Once centroid {@code c} is the best so far at distance {@code d}, any later centroid {@code c2}
     * with {@code d(c, c2) >= 2d} cannot be closer (triangle inequality), so its distance is skipped.
     */
    private void initialClusterSetUp(final int k, final int numPoints, final List<Vec> points, final List<Vec> means,
            final double[][] lowerBound, final double[] upperBound, final double[][] centroidSelfDistances,
            final int[] assignment, final AtomicDoubleArray meanCount, final Vec[] meanSums,
            final List<Double> accelCache, final List<List<Double>> centroidCache,
            final ThreadLocal<Vec[]> localDeltas, boolean parallel, final Vec weights) {
        ParallelUtils.run(parallel, numPoints, (start, end) -> {
            Vec[] deltas = localDeltas.get();
            boolean[] skip = new boolean[k];
            for (int q = start; q < end; q++) {
                Vec x = points.get(q);
                double minDist = Double.MAX_VALUE;
                int best = -1;
                Arrays.fill(skip, false);
                for (int c = 0; c < k; c++) {
                    if (skip[c]) {
                        continue;
                    }
                    double d = dm.dist(q, means.get(c), centroidCache.get(c), points, accelCache);
                    lowerBound[q][c] = d;
                    if (d < minDist) {
                        upperBound[q] = d;
                        minDist = d;
                        best = c;
                        for (int c2 = c + 1; c2 < k; c2++) {
                            if (centroidSelfDistances[c][c2] >= 2.0 * d) {
                                skip[c2] = true;
                            }
                        }
                    }
                }
                assignment[q] = best;
                double w = weights.get(q);
                meanCount.addAndGet(best, w);
                deltas[best].mutableAdd(w, x);
            }
            // Merge this chunk's weighted sums into the shared per-cluster sums.
            for (int c = 0; c < deltas.length; c++) {
                synchronized (meanSums[c]) {
                    meanSums[c].mutableAdd(deltas[c]);
                }
                deltas[c].zeroOut();
            }
        });
    }

    /**
     * Step 4: merges the calling thread's pending deltas into the shared per-cluster weighted sums
     * and clears them.
     */
    private void step4UpdateCentroids(Vec[] meanSums, ThreadLocal<Vec[]> localDeltas) {
        Vec[] deltas = localDeltas.get();
        for (int c = 0; c < deltas.length; c++) {
            if (deltas[c].nnz() != 0) {
                synchronized (meanSums[c]) {
                    meanSums[c].mutableAdd(deltas[c]);
                }
                deltas[c].zeroOut();
            }
        }
    }

    /**
     * Steps 5 and 6. Each centroid becomes its cluster's weighted mean (the zero vector when the
     * cluster weight is at most 1e-14), and the distance it moved is recorded. That distance then
     * lowers each point's lower bound to that centroid (clamped at 0) and raises the upper bound of
     * the points assigned to it. Finally every upper bound is flagged as possibly stale.
     */
    private void step5_6_distanceMovedBoundsUpdate(final int k, final Vec[] oldMeans, final List<Vec> means,
            final Vec[] meanSums, final AtomicDoubleArray meanCount, final int numPoints,
            final double[][] lowerBound, final double[] upperBound, final int[] assignment, final boolean[] r,
            final List<List<Double>> centroidCache, boolean parallel) {
        final double[] distChange = new double[k];
        ParallelUtils.run(parallel, k, c -> {
            Vec mean = means.get(c);
            mean.copyTo(oldMeans[c]);
            meanSums[c].copyTo(mean);
            double count = meanCount.get(c);
            if (count <= 1.0E-14) {
                mean.zeroOut();
            } else {
                mean.mutableDivide(count);
            }
            distChange[c] = dm.dist(oldMeans[c], mean);
            if (dm.supportsAcceleration()) {
                centroidCache.set(c, dm.getQueryInfo(mean));
            }
            // Step 5: a centroid that moved by delta can be at most delta closer to any point.
            for (int q = 0; q < numPoints; q++) {
                lowerBound[q][c] = Math.max(lowerBound[q][c] - distChange[c], 0.0);
            }
        });
        // Step 6: ... and at most delta farther from the points assigned to it.
        ParallelUtils.run(parallel, numPoints, (start, end) -> {
            for (int q = start; q < end; q++) {
                upperBound[q] += distChange[assignment[q]];
                r[q] = true;
            }
        });
    }

    /**
     * Step 3a: if point {@code q}'s upper bound is flagged stale, recompute the exact distance to its
     * assigned centroid and store it as both its upper bound and its lower bound for that centroid.
     */
    private void step3aBoundsUpdate(List<Vec> points, boolean[] r, int q, Vec v, List<Vec> means, int[] assignment,
            double[] upperBound, double[][] lowerBound, double[] meanSummaryConsts, List<Double> accelCache,
            List<List<Double>> centroidCache) {
        if (!r[q]) {
            return;
        }
        r[q] = false;
        int assigned = assignment[q];
        double d = dmds == null
                ? dm.dist(q, means.get(assigned), centroidCache.get(assigned), points, accelCache)
                : dmds.dist(meanSummaryConsts[assigned], means.get(assigned), v);
        lowerBound[q][assigned] = d;
        upperBound[q] = d;
    }

    /**
     * Step 3b: when the bounds cannot rule out centroid {@code c}, compute the distance from point
     * {@code q} to it. If that is closer than the current assignment, move the point's weight to
     * {@code c} (through the thread-local deltas) and flag that a change occurred.
     */
    private void step3bUpdate(List<Vec> points, double[] upperBound, int q, double[][] lowerBound, int c,
            double[][] centroidSelfDistances, int[] assignment, Vec v, List<Vec> means,
            ThreadLocal<Vec[]> localDeltas, AtomicDoubleArray meanCount, AtomicBoolean changeOccurred,
            double[] meanSummaryConsts, List<Double> accelCache, List<List<Double>> centroidCache, Vec weights) {
        if (upperBound[q] > lowerBound[q][c] || upperBound[q] > centroidSelfDistances[assignment[q]][c] / 2.0) {
            double d = dmds == null
                    ? dm.dist(q, means.get(c), centroidCache.get(c), points, accelCache)
                    : dmds.dist(meanSummaryConsts[c], means.get(c), v);
            lowerBound[q][c] = d;
            if (d < upperBound[q]) {
                Vec[] deltas = localDeltas.get();
                double w = weights.get(q);
                deltas[assignment[q]].mutableSubtract(w, v);
                meanCount.addAndGet(assignment[q], -w);
                deltas[c].mutableAdd(w, v);
                meanCount.addAndGet(c, w);
                assignment[q] = c;
                upperBound[q] = d;
                changeOccurred.set(true);
            }
        }
    }

    /**
     * Computes the symmetric centroid-to-centroid distance matrix and {@code sC[c]}, half the distance
     * from centroid {@code c} to its nearest other centroid. When {@code meanSummaryConsts} is
     * non-null, it also stores each centroid's {@link DenseSparseMetric} vector constant (this
     * requires {@code dmds} to be set).
     */
    private void calculateCentroidDistances(final int k, final double[][] centroidSelfDistances,
            final List<Vec> means, final double[] sC, final double[] meanSummaryConsts, boolean parallel) {
        final List<Double> meanAccelCache = dm.supportsAcceleration() ? dm.getAccelerationCache(means) : null;
        ParallelUtils.run(parallel, k, a -> {
            // Each task fills row a right of the diagonal plus its mirror, so tasks never overlap.
            for (int b = a + 1; b < k; b++) {
                double d = dm.dist(a, b, means, meanAccelCache);
                centroidSelfDistances[a][b] = d;
                centroidSelfDistances[b][a] = d;
            }
            if (meanSummaryConsts != null) {
                meanSummaryConsts[a] = dmds.getVectorConstant(means.get(a));
            }
        });
        for (int a = 0; a < k; a++) {
            double nearest = Double.MAX_VALUE;
            for (int b = 0; b < k; b++) {
                if (b != a) {
                    nearest = Math.min(nearest, centroidSelfDistances[a][b]);
                }
            }
            sC[a] = nearest / 2.0;
        }
    }

    /** Returns a deep copy of this clusterer (see the copy constructor). */
    @Override
    public ElkanKMeans mo114clone() {
        return new ElkanKMeans(this);
    }
}
