package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.LongAdder;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/kmeans/NaiveKMeans.class */
/**
 * K-Means implementation that, on every iteration, measures the distance from
 * every data point to every centroid and re-assigns each point to the closest
 * one. Iteration stops once a full pass changes no assignment.
 *
 * <p>This source was recovered by a decompiler; the parameter names of
 * {@link #cluster} and the name of {@link #mo114clone()} are decompiler
 * artifacts and are kept so the class stays consistent with the rest of the
 * decompiled tree.
 */
public class NaiveKMeans extends KMeans {
    private static final long serialVersionUID = 6164910874898843069L;

    /**
     * Creates a clusterer that uses a new {@link EuclideanDistance}.
     */
    public NaiveKMeans() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a clusterer with the given metric and
     * {@link SeedSelectionMethods.SeedSelection#KPP} seed selection.
     *
     * @param distanceMetric the distance metric to cluster with
     */
    public NaiveKMeans(DistanceMetric distanceMetric) {
        this(distanceMetric, SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Creates a clusterer with the given metric and seed selection, using the
     * generator returned by {@link RandomUtil#getRandom()}.
     *
     * @param distanceMetric the distance metric to cluster with
     * @param seedSelection  how the initial centroids are chosen
     */
    public NaiveKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection) {
        this(distanceMetric, seedSelection, RandomUtil.getRandom());
    }

    /**
     * Creates a clusterer; all three arguments are handed to the superclass
     * constructor unchanged.
     *
     * @param distanceMetric the distance metric to cluster with
     * @param seedSelection  how the initial centroids are chosen
     * @param random         the source of randomness
     */
    public NaiveKMeans(DistanceMetric distanceMetric, SeedSelectionMethods.SeedSelection seedSelection, Random random) {
        super(distanceMetric, seedSelection, random);
    }

    /**
     * Copy constructor; delegates to the superclass copy constructor.
     *
     * @param naiveKMeans the instance to copy
     */
    public NaiveKMeans(NaiveKMeans naiveKMeans) {
        super(naiveKMeans);
    }

    /**
     * Runs the clustering loop, writing the final assignment of every data
     * point into {@code iArr} and the final centroids into {@code list2}.
     *
     * <p>Each pass assigns every point to its closest centroid. Only points
     * whose assignment changed contribute a delta to the running per-cluster
     * vector sums and weight totals; the centroids are then recomputed as
     * {@code sum / weight}. The loop ends after the first pass in which no
     * assignment changed.
     *
     * @param dataSet the data to cluster
     * @param list    acceleration cache for the data points, or {@code null}
     *                to have one built via the distance metric
     * @param i       the number of clusters
     * @param list2   the centroids; if its size differs from {@code i} it is
     *                cleared and re-seeded. Sparse entries are replaced by
     *                dense copies. Updated in place
     * @param iArr    output array of cluster indices, one per data point; it
     *                is reset to {@code -1} before the first pass
     * @param z       not read by this implementation
     * @param z2      forwarded to metric training, cache construction, seed
     *                selection and {@code ParallelUtils.run}; presumably
     *                toggles parallel execution (confirm against those APIs)
     * @param z3      if {@code false} the method returns {@code 0.0} without
     *                computing the final error
     * @param vec     per-point weights, or {@code null} to use the weights
     *                supplied by {@code dataSet}
     * @return the sum of squared distances between each point and its
     *         assigned centroid when {@code z3} is {@code true}, otherwise
     *         {@code 0.0}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // jsat.clustering.kmeans.KMeans
    public double cluster(DataSet dataSet, List<Double> list, final int i, final List<Vec> list2, int[] iArr, boolean z, boolean z2, boolean z3, Vec vec) {
        TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet, z2);

        final Vec weights = vec == null ? dataSet.getDataWeights() : vec;
        final List<Vec> points = dataSet.getDataVectors();
        final List<Double> accelCache = list == null ? this.dm.getAccelerationCache(points, z2) : list;

        if (list2.size() != i) {
            list2.clear();
            list2.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, i, this.dm, accelCache, this.rand, this.seedSelection, z2));
        }

        // Per-centroid query info for the metric; taken before the centroid is
        // (possibly) replaced by its dense copy, as in the original ordering.
        final List<List<Double>> centroidQueryInfo = new ArrayList<>(i);
        for (int c = 0; c < list2.size(); c++) {
            if (this.dm.supportsAcceleration()) {
                centroidQueryInfo.add(this.dm.getQueryInfo(list2.get(c)));
            } else {
                centroidQueryInfo.add(Collections.<Double>emptyList());
            }
            if (list2.get(c).isSparse()) {
                list2.set(c, new DenseVector(list2.get(c)));
            }
        }

        // Running weighted vector sum and total weight of each cluster.
        final List<Vec> centroidSums = new ArrayList<>(list2.size());
        final AtomicDoubleArray centroidWeights = new AtomicDoubleArray(list2.size());
        for (int c = 0; c < i; c++) {
            centroidSums.add(new DenseVector(list2.get(0).length()));
        }

        final LongAdder changes = new LongAdder();

        // Each worker thread accumulates its deltas privately and merges them
        // into centroidSums once per chunk, under a per-centroid lock.
        final ThreadLocal<Vec[]> localDeltas = ThreadLocal.withInitial(() -> {
            Vec[] fresh = new Vec[i];
            for (int c = 0; c < i; c++) {
                fresh[c] = new DenseVector(list2.get(0).length());
            }
            return fresh;
        });

        final int n = dataSet.getSampleSize();
        Arrays.fill(iArr, -1);

        while (true) {
            changes.reset();

            ParallelUtils.run(z2, n, (start, end) -> {
                Vec[] deltas = localDeltas.get();
                for (int p = start; p < end; p++) {
                    Vec point = points.get(p);

                    // Find the closest centroid for this point.
                    double bestDist = Double.POSITIVE_INFINITY;
                    int best = -1;
                    for (int c = 0; c < list2.size(); c++) {
                        double dist = this.dm.dist(p, list2.get(c), centroidQueryInfo.get(c), points, accelCache);
                        if (dist < bestDist) {
                            bestDist = dist;
                            best = c;
                        }
                    }

                    // NOTE(review): best stays -1 if no distance compares less
                    // than +infinity (e.g. every distance is NaN); the indexing
                    // below would then fail. Unchanged from the original.
                    if (iArr[p] != best) {
                        double w = weights.get(p);
                        deltas[best].mutableAdd(w, point);
                        centroidWeights.addAndGet(best, w);
                        if (iArr[p] >= 0) {
                            // Withdraw the point from the cluster it is leaving.
                            deltas[iArr[p]].mutableSubtract(w, point);
                            centroidWeights.getAndAdd(iArr[p], -w);
                        }
                        iArr[p] = best;
                        changes.increment();
                    }
                }

                // Merge this chunk's deltas into the shared sums and clear them
                // so the thread-local buffers can be reused by the next chunk.
                for (int c = 0; c < deltas.length; c++) {
                    Vec sum = centroidSums.get(c);
                    synchronized (sum) {
                        sum.mutableAdd(deltas[c]);
                        deltas[c].zeroOut();
                    }
                }
            });

            if (changes.longValue() == 0) {
                break; // converged: no point changed cluster in this pass
            }

            for (int c = 0; c < i; c++) {
                centroidSums.get(c).copyTo(list2.get(c));
                // NOTE(review): no guard for a cluster whose accumulated weight
                // is zero; presumably that produces non-finite centroid values.
                // Behaviour kept as in the original - confirm before changing.
                list2.get(c).mutableDivide(centroidWeights.get(c));
                if (this.dm.supportsAcceleration()) {
                    centroidQueryInfo.set(c, this.dm.getQueryInfo(list2.get(c)));
                }
            }
        }

        if (!z3) {
            return 0.0d;
        }

        this.nearestCentroidDist = this.saveCentroidDistance ? new double[points.size()] : null;

        final Double totalSquaredDistance = ParallelUtils.run(z2, n, (start, end) -> {
            double localTotal = 0.0d;
            for (int p = start; p < end; p++) {
                double dist = this.dm.dist(p, list2.get(iArr[p]), centroidQueryInfo.get(iArr[p]), points, accelCache);
                localTotal += Math.pow(dist, 2.0d);
                if (this.saveCentroidDistance) {
                    this.nearestCentroidDist[p] = dist;
                }
            }
            return Double.valueOf(localTotal);
        }, (left, right) -> Double.valueOf(left.doubleValue() + right.doubleValue()));
        return totalSquaredDistance.doubleValue();
    }

    /**
     * Returns a copy of this clusterer built with the copy constructor.
     *
     * @return a new {@code NaiveKMeans} copied from this instance
     */
    @Override // jsat.clustering.kmeans.KMeans, jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public NaiveKMeans mo114clone() {
        return new NaiveKMeans(this);
    }
}
