package smile.clustering;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/**
 * K-Means clustering of real-valued vectors.
 *
 * <p>Each observation is assigned to the cluster whose centroid is nearest, and the centroids are
 * recomputed as the mean of their members. The two steps alternate until the distortion stops
 * decreasing or the iteration limit is reached.
 *
 * <p>Two drivers are provided:
 * <ul>
 *   <li>the constructors, which seed with {@code ClusteringDistance.EUCLIDEAN} and delegate each
 *       iteration to a {@link BBDTree};</li>
 *   <li>the static {@code lloyd} factories, which seed with
 *       {@code ClusteringDistance.EUCLIDEAN_MISSING_VALUES}, skip {@code NaN} (missing)
 *       coordinates when averaging centroids, and split the assignment step across threads for
 *       large inputs.</li>
 * </ul>
 */
public class KMeans extends PartitionClustering<double[]> implements Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /** Iteration limit used by the overloads that do not take one. */
    private static final int DEFAULT_MAX_ITER = 100;

    /** Minimum number of observations before the Lloyd assignment step is parallelized. */
    private static final int PARALLEL_MIN_SIZE = 1000;

    /** Minimum number of observations handled by one parallel Lloyd assignment task. */
    private static final int MIN_TASK_SIZE = 100;

    /** Distortion of this model; the multi-run entry points keep the model where it is lowest. */
    double distortion;

    /** Cluster centroids, one row per cluster. */
    double[][] centroids;

    /** Task that fits one BBD-tree K-Means model, so that several restarts can run in parallel. */
    static class KMeansThread implements Callable<KMeans> {
        final BBDTree bbd;
        final double[][] data;
        final int k;
        final int maxIter;

        KMeansThread(BBDTree bbd, double[][] data, int k, int maxIter) {
            this.bbd = bbd;
            this.data = data;
            this.k = k;
            this.maxIter = maxIter;
        }

        @Override
        public KMeans call() {
            return new KMeans(bbd, data, k, maxIter);
        }
    }

    /**
     * Task that assigns each observation in {@code [start, end)} to its nearest centroid.
     *
     * <p>The labels are written into the shared array {@code y}, and the task returns the summed
     * squared distance over its range. Each task writes only its own slice of {@code y} and only
     * reads {@code centroids}. Tasks over disjoint ranges therefore do not write the same
     * elements.
     */
    public static class LloydThread implements Callable<Double> {
        final int start;
        final int end;
        final double[][] data;
        final int k;
        final double[][] centroids;
        int[] y;

        LloydThread(double[][] data, double[][] centroids, int[] y, int start, int end) {
            this.data = data;
            this.k = centroids.length;
            this.y = y;
            this.centroids = centroids;
            this.start = start;
            this.end = end;
        }

        @Override
        public Double call() {
            double sum = 0.0;
            for (int i = start; i < end; i++) {
                double nearest = Double.MAX_VALUE;
                for (int c = 0; c < k; c++) {
                    double dist = PartitionClustering.squaredDistance(data[i], centroids[c]);
                    if (nearest > dist) {
                        y[i] = c;
                        nearest = dist;
                    }
                }
                sum += nearest;
            }
            return sum;
        }
    }

    /** Creates an empty model; the factory methods fill in its fields. */
    public KMeans() {
    }

    /** Returns the distortion of this model. */
    public double distortion() {
        return distortion;
    }

    /** Returns the cluster centroids (the internal array, not a copy). */
    public double[][] centroids() {
        return centroids;
    }

    /**
     * Returns the index of the centroid nearest to {@code x} according to
     * {@code Math.squaredDistance}. Ties go to the lower index.
     */
    @Override
    public int predict(double[] x) {
        int best = 0;
        double bestDist = Double.MAX_VALUE;
        for (int c = 0; c < k; c++) {
            double dist = Math.squaredDistance(x, centroids[c]);
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    /**
     * Fits a BBD-tree K-Means model with the default iteration limit.
     *
     * @param data the observations, one row each
     * @param k the number of clusters, at least 2
     * @throws IllegalArgumentException if {@code k < 2}
     */
    public KMeans(double[][] data, int k) {
        this(data, k, DEFAULT_MAX_ITER);
    }

    /**
     * Fits a BBD-tree K-Means model.
     *
     * @param data the observations, one row each
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations, positive
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}
     */
    public KMeans(double[][] data, int k, int maxIter) {
        this(new BBDTree(data), data, k, maxIter);
    }

    /**
     * Fits a K-Means model using a prebuilt BBD tree over {@code data}.
     *
     * <p>Iteration stops as soon as the distortion fails to decrease. At that point the
     * centroids have already been moved by that final, non-improving iteration.
     */
    KMeans(BBDTree bbd, double[][] data, int k, int maxIter) {
        checkArguments(k, maxIter);

        int n = data.length;
        int d = data[0].length;
        this.k = k;
        this.distortion = Double.MAX_VALUE;
        this.y = seed(data, k, ClusteringDistance.EUCLIDEAN);
        this.size = new int[k];
        this.centroids = new double[k][d];

        // Initial centroids are the means of the points labelled by the seeding step.
        for (int i = 0; i < n; i++) {
            int c = this.y[i];
            this.size[c]++;
            for (int j = 0; j < d; j++) {
                this.centroids[c][j] += data[i][j];
            }
        }
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < d; j++) {
                this.centroids[c][j] /= this.size[c];
            }
        }

        // As consumed below, bbd.clustering fills sums (per-cluster coordinate sums), size and y,
        // and returns the distortion of the new assignment.
        double[][] sums = new double[k][d];
        for (int iter = 1; iter <= maxIter; iter++) {
            double dist = bbd.clustering(this.centroids, sums, this.size, this.y);
            for (int c = 0; c < k; c++) {
                // Empty clusters keep their previous centroid.
                if (this.size[c] > 0) {
                    for (int j = 0; j < d; j++) {
                        this.centroids[c][j] = sums[c][j] / this.size[c];
                    }
                }
            }
            if (this.distortion <= dist) {
                break;
            }
            this.distortion = dist;
        }
    }

    /**
     * Fits {@code runs} BBD-tree K-Means models and keeps the one with the lowest distortion.
     *
     * <p>The runs are submitted to {@link MulticoreExecutor}. If that fails, the error is logged
     * and {@code runs} sequential {@link #lloyd(double[][], int, int)} fits are used instead.
     *
     * @param data the observations, one row each
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations per run, positive
     * @param runs the number of restarts, positive
     * @throws IllegalArgumentException if any argument is out of range
     */
    public KMeans(double[][] data, int k, int maxIter, int runs) {
        checkArguments(k, maxIter);
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }

        BBDTree bbd = new BBDTree(data);
        List<KMeansThread> tasks = new ArrayList<>(runs);
        for (int run = 0; run < runs; run++) {
            tasks.add(new KMeansThread(bbd, data, k, maxIter));
        }

        // Start from the first model rather than an empty placeholder. Otherwise a run whose
        // distortion is never below MAX_VALUE (e.g. NaN) would leave this object without centroids.
        KMeans best = null;
        try {
            for (KMeans model : MulticoreExecutor.run(tasks)) {
                if (best == null || model.distortion < best.distortion) {
                    best = model;
                }
            }
        } catch (Exception e) {
            logger.error("Failed to run K-Means on multi-core", e);
            best = null;
            for (int run = 0; run < runs; run++) {
                KMeans model = lloyd(data, k, maxIter);
                if (best == null || model.distortion < best.distortion) {
                    best = model;
                }
            }
        }

        this.k = best.k;
        this.distortion = best.distortion;
        this.centroids = best.centroids;
        this.y = best.y;
        this.size = best.size;
    }

    /**
     * Runs Lloyd's algorithm with the default iteration limit.
     *
     * @see #lloyd(double[][], int, int)
     */
    public static KMeans lloyd(double[][] data, int k) {
        return lloyd(data, k, DEFAULT_MAX_ITER);
    }

    /**
     * Runs Lloyd's algorithm.
     *
     * <p>{@code NaN} coordinates are treated as missing: they are skipped when averaging
     * centroids. On inputs of at least {@value #PARALLEL_MIN_SIZE} rows the assignment step is
     * split across threads. The iteration stops when the distortion fails to decrease.
     *
     * <p>NOTE(review): the returned distortion is that of the last improving iteration, while
     * the labels come from the final (possibly non-improving) assignment.
     *
     * @param data the observations, one row each
     * @param k the number of clusters, at least 2
     * @param maxIter the maximum number of iterations, positive
     * @return the fitted model
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter) {
        checkArguments(k, maxIter);

        int d = data[0].length;
        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        // Per-cluster, per-dimension number of non-missing values; the divisor of each mean.
        int[][] counts = new int[k][d];
        int[] y = seed(data, k, ClusteringDistance.EUCLIDEAN_MISSING_VALUES);

        // The tasks hold references to centroids and y, so both must only be updated in place.
        List<LloydThread> tasks = lloydTasks(data, centroids, y);

        double distortion = Double.MAX_VALUE;
        for (int iter = 0; iter < maxIter; iter++) {
            updateCentroids(data, y, centroids, size, counts);
            double dist = assign(data, centroids, y, tasks);
            if (distortion <= dist) {
                break;
            }
            distortion = dist;
        }

        // Make the centroids and sizes consistent with the final labels.
        updateCentroids(data, y, centroids, size, counts);

        KMeans kmeans = new KMeans();
        kmeans.k = k;
        kmeans.distortion = distortion;
        kmeans.size = size;
        kmeans.centroids = centroids;
        kmeans.y = y;
        return kmeans;
    }

    /**
     * Runs Lloyd's algorithm {@code runs} times and returns the model with the lowest distortion.
     *
     * @throws IllegalArgumentException if any argument is out of range
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter, int runs) {
        checkArguments(k, maxIter);
        if (runs <= 0) {
            throw new IllegalArgumentException("Invalid number of runs: " + runs);
        }
        KMeans best = lloyd(data, k, maxIter);
        for (int run = 1; run < runs; run++) {
            KMeans model = lloyd(data, k, maxIter);
            if (model.distortion < best.distortion) {
                best = model;
            }
        }
        return best;
    }

    /** Validates the cluster count and the iteration limit shared by every entry point. */
    private static void checkArguments(int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
    }

    /**
     * Splits the assignment step into parallel tasks over contiguous row ranges.
     *
     * <p>Returns an empty list when the input is too small or only one thread is available. It
     * creates at most one task per pool thread, each covering at least {@value #MIN_TASK_SIZE}
     * rows (except possibly the last), and never a range that extends past {@code data.length}.
     */
    private static List<LloydThread> lloydTasks(double[][] data, double[][] centroids, int[] y) {
        int n = data.length;
        int np = MulticoreExecutor.getThreadPoolSize();
        List<LloydThread> tasks = new ArrayList<>();
        if (n < PARALLEL_MIN_SIZE || np < 2) {
            return tasks;
        }

        int step = n / np;
        if (step < MIN_TASK_SIZE) {
            step = MIN_TASK_SIZE;
        }
        // Full chunks while more than one chunk's worth remains; the last task takes the rest.
        int start = 0;
        while (tasks.size() < np - 1 && step < n - start) {
            tasks.add(new LloydThread(data, centroids, y, start, start + step));
            start += step;
        }
        tasks.add(new LloydThread(data, centroids, y, start, n));
        return tasks;
    }

    /**
     * Relabels every observation with its nearest centroid and returns the summed squared
     * distance.
     *
     * <p>The parallel tasks are used when available. The computation falls back to a single
     * sequential pass if there are no tasks, the executor fails, or the parallel sum is NaN.
     */
    private static double assign(double[][] data, double[][] centroids, int[] y, List<LloydThread> tasks) {
        double dist = Double.NaN;
        if (!tasks.isEmpty()) {
            try {
                dist = 0.0;
                for (double part : MulticoreExecutor.run(tasks)) {
                    dist += part;
                }
            } catch (Exception e) {
                logger.error("Failed to run K-Means on multi-core", e);
                dist = Double.NaN;
            }
        }
        if (Double.isNaN(dist)) {
            dist = new LloydThread(data, centroids, y, 0, data.length).call();
        }
        return dist;
    }

    /**
     * Recomputes cluster sizes and centroids in place from the labels {@code y}.
     *
     * <p>{@code NaN} coordinates are skipped, so each centroid coordinate is the mean of the
     * non-missing values. When none are present, the result is 0/0 = NaN.
     */
    private static void updateCentroids(double[][] data, int[] y, double[][] centroids, int[] size, int[][] counts) {
        Arrays.fill(size, 0);
        for (int c = 0; c < centroids.length; c++) {
            Arrays.fill(centroids[c], 0.0);
            Arrays.fill(counts[c], 0);
        }
        for (int i = 0; i < data.length; i++) {
            int c = y[i];
            size[c]++;
            double[] centroid = centroids[c];
            int[] count = counts[c];
            for (int j = 0; j < centroid.length; j++) {
                double v = data[i][j];
                if (!Double.isNaN(v)) {
                    centroid[j] += v;
                    count[j]++;
                }
            }
        }
        for (int c = 0; c < centroids.length; c++) {
            double[] centroid = centroids[c];
            for (int j = 0; j < centroid.length; j++) {
                centroid[j] /= counts[c][j];
            }
        }
    }

    /** Summarizes the distortion and the size and percentage share of each cluster. */
    @Override
    public String toString() {
        int n = y.length;
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("K-Means distortion: %.5f%n", distortion));
        sb.append(String.format("Clusters of %d data points of dimension %d:%n", n, centroids[0].length));
        for (int c = 0; c < k; c++) {
            // Share in tenths of a percent, printed as "<whole>.<tenth>%".
            int permille = (int) Math.round((1000.0d * size[c]) / n);
            sb.append(String.format("%3d\t%5d (%2d.%1d%%)%n", c, size[c], permille / 10, permille % 10));
        }
        return sb.toString();
    }
}
