package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/kmeans/HamerlyKMeans.class */
public class HamerlyKMeans extends KMeans {
    private static final long serialVersionUID = -4960453870335145091L;

    /**
     * Creates a new Hamerly k-means clusterer.
     *
     * @param distanceMetric the distance metric used to compare points and centers
     * @param seedSelection the method used to pick initial centers when none are supplied
     * @param random the source of randomness handed to the parent class
     */
    public HamerlyKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection, Random random) {
        super(distanceMetric, seedSelection, random);
    }

    /**
     * Creates a new Hamerly k-means clusterer that uses {@link RandomUtil#getRandom()} as its randomness source.
     *
     * @param distanceMetric the distance metric used to compare points and centers
     * @param seedSelection the method used to pick initial centers when none are supplied
     */
    public HamerlyKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection) {
        this(distanceMetric, seedSelection, RandomUtil.getRandom());
    }

    /**
     * Creates a clusterer using {@link EuclideanDistance} and the {@code KPP} seed selection method.
     */
    public HamerlyKMeans() {
        this(new EuclideanDistance(), SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Copy constructor; all copied state is handled by the parent class.
     *
     * @param hamerlyKMeans the clusterer to copy
     */
    public HamerlyKMeans(HamerlyKMeans hamerlyKMeans) {
        super(hamerlyKMeans);
    }

    /**
     * Clusters {@code dataSet} with Lloyd iterations accelerated by distance bounds: every point keeps an upper
     * bound {@code u} on the distance to its assigned center and a lower bound {@code l} on the distance to the
     * nearest other center, and is only re-examined when those bounds cannot rule out a reassignment. Iteration
     * stops after a pass in which no point changes cluster.
     *
     * <p>The decompiled declaration is public; the decompiler noted the original modifier was protected.
     *
     * @param dataSet the data to cluster
     * @param accelCache distance acceleration cache for the data vectors, or {@code null} to compute it here
     * @param k the number of clusters
     * @param means the centers; if it does not hold exactly {@code k} entries it is refilled by seed selection.
     *        Sparse centers are replaced by dense copies and the list holds the final centers on return
     * @param assignment output: index of the center each point is assigned to
     * @param exactTotal if {@code true} the returned error uses freshly computed distances, otherwise each
     *        point's upper bound {@code u} is used
     * @param parallel whether work may be spread over multiple threads
     * @param returnError if {@code false}, no error is computed and {@code 0} is returned
     * @param dataPointWeights per-point weights, or {@code null} to use {@code dataSet.getDataWeights()}
     * @return the sum of squared point-to-center distances (see {@code exactTotal}), or 0 when
     *         {@code returnError} is {@code false}
     */
    @Override // jsat.clustering.kmeans.KMeans
    public double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment,
            boolean exactTotal, boolean parallel, boolean returnError, Vec dataPointWeights) {
        final int n = dataSet.getSampleSize();
        final int dims = dataSet.getNumNumericalVars();
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, parallel);
        final Vec weights = dataPointWeights == null ? dataSet.getDataWeights() : dataPointWeights;
        final List<Vec> x = dataSet.getDataVectors();
        final List<Double> distAccel = accelCache == null ? this.dm.getAccelerationCache(x, parallel) : accelCache;
        final List<List<Double>> meanQIs = new ArrayList<>(k);

        if (means.size() != k) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, k, this.dm, distAccel, this.rand,
                    this.seedSelection, parallel));
        }

        // Make every center dense; oldMeans remembers each center's previous position.
        final Vec[] oldMeans = new Vec[means.size()];
        final double[] meanDistToOld = new double[means.size()];
        for (int c = 0; c < means.size(); c++) {
            if (means.get(c).isSparse()) {
                means.set(c, new DenseVector(means.get(c)));
            }
            oldMeans[c] = new DenseVector(means.get(c));
        }

        final Vec[] weightedSums = new Vec[k]; // weighted sum of the points owned by each center
        final Vec[] tmpVecs = new Vec[k];      // per-center scratch space
        final Vec[] tmpVecs2 = new Vec[k];     // per-center scratch space
        for (int c = 0; c < tmpVecs2.length; c++) {
            tmpVecs2[c] = new DenseVector(oldMeans[0].length());
        }

        final AtomicDoubleArray clusterWeight = new AtomicDoubleArray(k);
        final double[] moved = new double[k];             // distance each center moved in the last update
        final double[] nearestCenterDist = new double[k]; // distance from each center to its closest other center
        final double[] upper = new double[n];             // u: upper bound on distance to the assigned center
        final double[] lower = new double[n];             // l: lower bound on distance to the nearest other center

        // Each thread accumulates assignment changes into its own delta vectors; all of them are tracked here.
        final List<Vec[]> allLocalDeltas = Collections.synchronizedList(new ArrayList<>());
        final ThreadLocal<Vec[]> localDeltas = ThreadLocal.withInitial(() -> {
            Vec[] deltas = new Vec[means.size()];
            for (int c = 0; c < k; c++) {
                deltas[c] = new DenseVector(dims);
            }
            allLocalDeltas.add(deltas);
            return deltas;
        });

        Initialize(dataSet, clusterWeight, means, tmpVecs, weightedSums, upper, lower, assignment, parallel,
                localDeltas, x, distAccel, meanQIs, weights);

        int updates = n;
        while (updates > 0) {
            moveCenters(means, oldMeans, tmpVecs, weightedSums, clusterWeight, moved, meanQIs);
            updateS(nearestCenterDist, meanDistToOld, means, oldMeans, parallel, meanQIs);

            // Largest upper bound among the points currently owned by each cluster.
            final double[] clusterRadius = new double[means.size()];
            for (int i = 0; i < n; i++) {
                clusterRadius[assignment[i]] = Math.max(clusterRadius[assignment[i]], upper[i]);
            }
            EnhancedUpdateBounds(means, meanDistToOld, clusterRadius, nearestCenterDist, oldMeans, tmpVecs,
                    tmpVecs2, new double[clusterRadius.length], moved, assignment, upper, lower);

            updates = ParallelUtils.run(parallel, n,
                    i -> mainLoopWork(dataSet, i, nearestCenterDist, assignment, upper, lower, clusterWeight,
                            localDeltas.get(), x, distAccel, means, meanQIs, weights),
                    (a, b) -> a + b);

            // Iterating a synchronizedList requires holding its lock, so fold from a snapshot.
            final List<Vec[]> deltaSnapshot;
            synchronized (allLocalDeltas) {
                deltaSnapshot = new ArrayList<>(allLocalDeltas);
            }
            ParallelUtils.range(weightedSums.length, parallel).forEach(c -> {
                for (Vec[] deltas : deltaSnapshot) {
                    weightedSums[c].mutableAdd(deltas[c]);
                    deltas[c].zeroOut();
                }
            });
        }

        if (!returnError) {
            return 0.0d;
        }
        this.nearestCentroidDist = this.saveCentroidDistance ? new double[n] : null;

        return ParallelUtils.run(parallel, n, (start, end) -> {
            double sqrdErr = 0.0d;
            for (int i = start; i < end; i++) {
                final int owner = assignment[i];
                final double dist = exactTotal
                        ? this.dm.dist(i, means.get(owner), meanQIs.get(owner), x, distAccel)
                        : upper[i];
                sqrdErr += dist * dist;
                if (this.saveCentroidDistance) {
                    this.nearestCentroidDist[i] = dist;
                }
            }
            return sqrdErr;
        }, (e1, e2) -> e1 + e2);
    }

    /**
     * For every center {@code i}, computes {@code r[i]}: a bound on how much the distance from a point of
     * cluster {@code i} to some other center can have shrunk because that center moved. The bound treats the
     * cluster as a ball of radius {@code clusterRadius[i]} around the old center. It then calls
     * {@link #UpdateBounds} to apply the bounds.
     *
     * <p>The geometry works in normalized 2-D coordinates along center {@code j}'s motion: the old {@code c_j}
     * is at (1, 0) and the new {@code c'_j} is at (-1, 0).
     * NOTE(review): the geometry assumes a Euclidean metric.
     *
     * @param means the new center positions
     * @param meanDistToOld distance each center moved
     * @param clusterRadius largest upper bound among each cluster's points
     * @param s distance from each center to its nearest other center
     * @param oldMeans the previous center positions
     * @param tmpVecs scratch vectors, one per center (overwritten)
     * @param tmpVecs2 scratch vectors, one per center (overwritten)
     * @param r output: per-center bound on the lower-bound decrease
     * @param p distance each center moved (forwarded to {@link #UpdateBounds})
     * @param a point assignments
     * @param u per-point upper bounds (updated)
     * @param l per-point lower bounds (updated)
     */
    private void EnhancedUpdateBounds(List<Vec> means, double[] meanDistToOld, double[] clusterRadius, double[] s,
            Vec[] oldMeans, Vec[] tmpVecs, Vec[] tmpVecs2, double[] r, double[] p, int[] a, double[] u,
            double[] l) {
        // The visiting order depends only on meanDistToOld, which is never written here, so sort once.
        final IndexTable byMovement = new IndexTable(meanDistToOld);
        byMovement.reverse();
        for (int i = 0; i < means.size(); i++) {
            // NOTE(review): if every other center is skipped by the filter below, r[i] stays -infinity and
            // UpdateBounds then raises l to +infinity for cluster i's points; confirm the filter justifies that.
            double r_i = Double.NEGATIVE_INFINITY;
            for (int z = 0; z < byMovement.length(); z++) {
                final int j = byMovement.index(z);
                if (j == i || !(2.0d * clusterRadius[j] + s[j] >= meanDistToOld[j])) {
                    continue;
                }
                final double pj = meanDistToOld[j];
                // A center's contribution never exceeds its movement, and centers arrive in decreasing movement.
                if (pj <= r_i) {
                    break;
                }
                // A center that did not move cannot bring any point closer; avoid dividing by zero below.
                if (pj == 0.0d) {
                    r_i = Math.max(r_i, 0.0d);
                    break;
                }
                oldMeans[i].copyTo(tmpVecs[i]);
                means.get(j).copyTo(tmpVecs2[i]);
                tmpVecs[i].mutableSubtract(oldMeans[j]);   // c_i - c_j
                tmpVecs2[i].mutableSubtract(oldMeans[j]);  // c'_j - c_j (the motion of center j)
                // Fraction of j's motion covered by the projection of c_i - c_j onto that motion.
                final double t = tmpVecs[i].dot(tmpVecs2[i]) / (pj * pj);
                // tmpVecs[i] := t (c'_j - c_j) - (c_i - c_j), the negated component of c_i - c_j that is
                // orthogonal to j's motion; its norm is c_i's distance from the line of motion.
                tmpVecs[i].mutableMultiply(-1.0d);
                tmpVecs[i].mutableAdd(t, tmpVecs2[i]);
                final double ciY = 2.0d * tmpVecs[i].pNorm(2.0d) / pj;
                double ciX = 1.0d - 2.0d * t;
                final double radius = 2.0d * clusterRadius[i] / pj;
                final double bound;
                if (ciY <= radius) {
                    // The ball reaches the line of motion; bound by its extent along that line.
                    bound = Math.max(0.0d, Math.min(2.0d, 2.0d * (radius - ciX)));
                } else {
                    // The ball lies off the line: bound via the tangent ray of largest angle.
                    // Take the angle from c_j when the ball lies beyond radius, otherwise from the midpoint.
                    if (ciX > radius) {
                        ciX -= 1.0d;
                    }
                    final double h2 = ciY * ciY + ciX * ciX;
                    bound = 2.0d * (ciY * radius - ciX * Math.sqrt(h2 - radius * radius)) / h2;
                }
                r_i = Math.max(bound * (pj / 2.0d), r_i);
            }
            r[i] = r_i;
        }
        UpdateBounds(p, a, u, l, r);
    }

    /**
     * Processes point {@code i} for one iteration. The stored bounds are tried first, then a tightened upper
     * bound, and only then are distances to all centers recomputed. If the point changes cluster, the cluster
     * weights and this thread's delta vectors are updated.
     *
     * @return 1 if the point changed cluster, otherwise 0
     */
    private int mainLoopWork(DataSet dataSet, int i, double[] s, int[] assignment, double[] u, double[] l,
            AtomicDoubleArray q, Vec[] deltas, List<Vec> x, List<Double> distAccel, List<Vec> means,
            List<List<Double>> meanQIs, Vec weights) {
        final int oldOwner = assignment[i];
        final double m = Math.max(s[oldOwner] / 2.0d, l[i]);
        if (u[i] <= m) {
            return 0;
        }
        final Vec xi = x.get(i);
        u[i] = this.dm.dist(i, means.get(oldOwner), meanQIs.get(oldOwner), x, distAccel); // tighten u
        if (u[i] <= m) {
            return 0;
        }
        final int newOwner = PointAllCtrs(xi, i, means, assignment, u, l, x, distAccel, meanQIs);
        if (newOwner == oldOwner) {
            return 0;
        }
        final double w = weights.get(i);
        q.addAndGet(oldOwner, -w);
        q.addAndGet(newOwner, w);
        deltas[oldOwner].mutableSubtract(w, xi);
        deltas[newOwner].mutableAdd(w, xi);
        return 1;
    }

    /**
     * Fills {@code s[c]} with the distance from center {@code c} to its closest other center (or
     * {@code Double.MAX_VALUE} if there is none). Also stores in {@code meanDistToOld[c]} how far center
     * {@code c} moved from {@code oldMeans[c]}.
     */
    private void updateS(double[] s, double[] meanDistToOld, List<Vec> means, Vec[] oldMeans, boolean parallel,
            List<List<Double>> meanQIs) {
        Arrays.fill(s, Double.MAX_VALUE);
        final DoubleList meanCache = meanQIs.get(0).isEmpty() ? null : new DoubleList(meanQIs.size());
        if (meanCache != null) {
            for (List<Double> qi : meanQIs) {
                meanCache.addAll(qi);
            }
        }
        final ThreadLocal<double[]> localS = ThreadLocal.withInitial(() -> new double[s.length]);
        ParallelUtils.run(parallel, means.size(), j -> {
            final double[] sTmp = localS.get();
            Arrays.fill(sTmp, Double.POSITIVE_INFINITY);
            meanDistToOld[j] = this.dm.dist(oldMeans[j], means.get(j));
            // Each pair (j, jp) with jp > j is visited once; its distance updates both endpoints.
            for (int jp = j + 1; jp < means.size(); jp++) {
                final double d = this.dm.dist(j, jp, means, meanCache);
                sTmp[j] = Math.min(sTmp[j], d);
                sTmp[jp] = Math.min(sTmp[jp], d);
            }
            synchronized (s) {
                for (int c = 0; c < s.length; c++) {
                    s[c] = Math.min(s[c], sTmp[c]);
                }
            }
        });
    }

    /**
     * Builds the per-center query info and zeroed sum vectors. Then assigns every point to its nearest center,
     * which sets its bounds, and accumulates the center weights and weighted point sums.
     */
    private void Initialize(DataSet dataSet, AtomicDoubleArray q, List<Vec> means, Vec[] tmp, Vec[] weightedSums,
            double[] u, double[] l, int[] assignment, boolean parallel, ThreadLocal<Vec[]> localDeltas,
            List<Vec> x, List<Double> distAccel, List<List<Double>> meanQIs, Vec weights) {
        for (int c = 0; c < means.size(); c++) {
            weightedSums[c] = new DenseVector(means.get(0).length());
            tmp[c] = weightedSums[c].mo46clone();
            if (this.dm.supportsAcceleration()) {
                meanQIs.add(this.dm.getQueryInfo(means.get(c)));
            } else {
                meanQIs.add(Collections.emptyList());
            }
        }
        ParallelUtils.run(parallel, u.length, (start, end) -> {
            final Vec[] deltas = localDeltas.get();
            for (int i = start; i < end; i++) {
                final Vec xi = x.get(i);
                final int owner = PointAllCtrs(xi, i, means, assignment, u, l, x, distAccel, meanQIs);
                final double w = weights.get(i);
                q.addAndGet(owner, w);
                deltas[owner].mutableAdd(w, xi);
            }
            // Fold this chunk's sums into the shared vectors and reset the thread-local scratch.
            for (int c = 0; c < weightedSums.length; c++) {
                synchronized (weightedSums[c]) {
                    weightedSums[c].mutableAdd(deltas[c]);
                }
                deltas[c].zeroOut();
            }
        });
    }

    /**
     * Computes point {@code i}'s distance to every center. It records the closest center in
     * {@code assignment[i]}, that distance in {@code u[i]}, and the second-smallest distance in {@code l[i]}.
     *
     * @return the index of the closest center
     */
    private int PointAllCtrs(Vec xi, int i, List<Vec> means, int[] assignment, double[] u, double[] l,
            List<Vec> x, List<Double> distAccel, List<List<Double>> meanQIs) {
        double secondBest = Double.POSITIVE_INFINITY;
        double best = Double.MAX_VALUE;
        int bestIndex = -1;
        for (int c = 0; c < means.size(); c++) {
            final double dist = this.dm.dist(i, means.get(c), meanQIs.get(c), x, distAccel);
            if (dist >= secondBest) {
                continue;
            }
            if (dist < best) {
                secondBest = best;
                best = dist;
                bestIndex = c;
            } else {
                secondBest = dist;
            }
        }
        assignment[i] = bestIndex;
        u[i] = best;
        l[i] = secondBest;
        return bestIndex;
    }

    /**
     * Moves every center to the weighted mean of its points and records the movement in {@code p}. A center
     * whose accumulated weight is not positive has its sum reset and is set to the zero vector. The previous
     * positions are saved in {@code oldMeans}, and query info is refreshed when the metric supports acceleration.
     */
    private void moveCenters(List<Vec> means, Vec[] oldMeans, Vec[] tmp, Vec[] weightedSums, AtomicDoubleArray q,
            double[] p, List<List<Double>> meanQIs) {
        for (int c = 0; c < means.size(); c++) {
            final double weight = q.get(c);
            means.get(c).copyTo(oldMeans[c]);
            if (weight > 0.0d) {
                weightedSums[c].copyTo(tmp[c]);
                tmp[c].mutableDivide(weight);
            } else {
                weightedSums[c].zeroOut();
                tmp[c].zeroOut();
            }
            p[c] = this.dm.dist(means.get(c), tmp[c]);
            tmp[c].copyTo(means.get(c));
            if (this.dm.supportsAcceleration()) {
                meanQIs.set(c, this.dm.getQueryInfo(means.get(c)));
            }
        }
    }

    /**
     * Applies center movement to the per-point bounds. Each upper bound grows by the movement of the point's
     * own center. Each lower bound shrinks by the smaller of (a) the largest movement among the other centers
     * and (b) {@code r} for the point's cluster.
     *
     * @param p distance each center moved
     * @param a point assignments
     * @param u per-point upper bounds (updated)
     * @param l per-point lower bounds (updated)
     * @param r per-cluster bound on the lower-bound decrease
     */
    private void UpdateBounds(double[] p, int[] a, double[] u, double[] l, double[] r) {
        // Find the two centers that moved the most.
        double secondMax = Double.NEGATIVE_INFINITY;
        int secondIdx = -1;
        double max = -Double.MAX_VALUE;
        int maxIdx = -1;
        for (int c = 0; c < p.length; c++) {
            final double pc = p[c];
            if (pc > secondMax) {
                if (pc > max) {
                    secondMax = max;
                    secondIdx = maxIdx;
                    max = pc;
                    maxIdx = c;
                } else {
                    secondMax = pc;
                    secondIdx = c;
                }
            }
        }
        // With a single center there is no "other" center, so it contributes no movement.
        final double maxMove = maxIdx < 0 ? 0.0d : p[maxIdx];
        final double secondMove = secondIdx < 0 ? 0.0d : p[secondIdx];
        for (int i = 0; i < u.length; i++) {
            final int owner = a[i];
            u[i] += p[owner];
            final double otherMax = owner == maxIdx ? secondMove : maxMove;
            l[i] -= Math.min(otherMax, r[owner]);
        }
    }

    @Override // jsat.clustering.kmeans.KMeans, jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public HamerlyKMeans mo114clone() {
        return new HamerlyKMeans(this);
    }
}
