package jsat.clustering;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.math.OnLineStatistics;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/MEDDIT.class */
/**
 * A {@link PAM} variant whose medoid search is based on random sampling of
 * pairwise distances instead of evaluating every pair.
 * <p>
 * For every candidate point the search keeps a running mean of sampled
 * distances together with a confidence radius, which yields a lower and an
 * upper bound on the candidate's true mean distance to the other candidates.
 * Candidates with the smallest lower bound are sampled further, and the search
 * stops as soon as the smallest upper bound lies below the smallest lower
 * bound. Candidates that have been sampled at least {@code N - 1} times are
 * evaluated exactly instead.
 */
public class MEDDIT extends PAM {
    /** Tolerance used when none is set explicitly. */
    private static final double DEFAULT_TOLERANCE = 0.01d;

    /** Cap used for both the candidates refined per round and the samples drawn per candidate. */
    private static final int BATCH_SIZE = 32;

    private double tolerance = DEFAULT_TOLERANCE;

    /**
     * Creates a new MEDDIT clusterer.
     *
     * @param distanceMetric the distance metric, forwarded to {@link PAM}
     * @param random the source of randomness, forwarded to {@link PAM}
     * @param seedSelection the seed selection method, forwarded to {@link PAM}
     */
    public MEDDIT(DistanceMetric distanceMetric, Random random, SeedSelectionMethods.SeedSelection seedSelection) {
        super(distanceMetric, random, seedSelection);
    }

    /**
     * Creates a new MEDDIT clusterer.
     *
     * @param distanceMetric the distance metric, forwarded to {@link PAM}
     * @param random the source of randomness, forwarded to {@link PAM}
     */
    public MEDDIT(DistanceMetric distanceMetric, Random random) {
        super(distanceMetric, random);
    }

    /**
     * Creates a new MEDDIT clusterer.
     *
     * @param distanceMetric the distance metric, forwarded to {@link PAM}
     */
    public MEDDIT(DistanceMetric distanceMetric) {
        super(distanceMetric);
    }

    /**
     * Creates a new MEDDIT clusterer using the defaults of {@link PAM}.
     */
    public MEDDIT() {
        super();
    }

    /**
     * Sets the tolerance handed to the medoid search during clustering.
     * A negative value makes {@code cluster} use {@code 1 / N}, with
     * {@code N} the number of data points. A value of zero makes the medoid
     * search delegate to the exact {@link PAM} implementation.
     *
     * @param d the new tolerance
     */
    public void setTolerance(double d) {
        this.tolerance = d;
    }

    /**
     * @return the tolerance as last set (or the default), without any
     * adjustment for negative values
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Alternates between assigning every point to its nearest medoid and
     * recomputing the medoid of every non-empty cluster, until an assignment
     * pass changes nothing or the inherited iteration limit is exceeded.
     *
     * @param data the data set to cluster
     * @param doInit if {@code true}, the distance metric is trained when
     * needed, a fresh acceleration cache is built and the initial medoids are
     * selected into {@code medoids}; otherwise {@code medoids} and
     * {@code cacheAccel} are used as given
     * @param medoids data-set indices of the medoids; overwritten with the
     * medoids found
     * @param assignments receives, for every data point, the position in
     * {@code medoids} of its nearest medoid; reset to -1 on entry
     * @param cacheAccel the acceleration cache to use when {@code doInit} is
     * {@code false}
     * @param parallel whether the work may be run in parallel
     * @return the sum of squared distances from every point to its nearest
     * medoid, as measured by the last assignment pass
     */
    @Override // jsat.clustering.PAM
    public double cluster(DataSet data, boolean doInit, int[] medoids, int[] assignments, List<Double> cacheAccel, boolean parallel) {
        final DoubleAdder totalSquaredDistance = new DoubleAdder();
        final LongAdder changes = new LongAdder();
        Arrays.fill(assignments, -1);
        final List<Vec> X = data.getDataVectors();
        final int N = data.getSampleSize();

        final List<Double> accel;
        if (doInit) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, data);
            accel = this.dm.getAccelerationCache(X);
            SeedSelectionMethods.selectIntialPoints(data, medoids, this.dm, accel, this.rand, this.seedSelection);
        } else {
            accel = cacheAccel;
        }

        final double tol = this.tolerance < 0.0d ? 1.0d / N : this.tolerance;
        final IntList members = new IntList(N);
        int iter = 0;
        do {
            changes.reset();
            totalSquaredDistance.reset();

            // Assignment step: attach every point to its nearest medoid.
            ParallelUtils.run(parallel, N, (start, end) -> {
                for (int i = start; i < end; i++) {
                    int nearest = 0;
                    double nearestDist = this.dm.dist(medoids[0], i, X, accel);
                    for (int k = 1; k < medoids.length; k++) {
                        double dist = this.dm.dist(medoids[k], i, X, accel);
                        if (dist < nearestDist) {
                            nearestDist = dist;
                            nearest = k;
                        }
                    }
                    if (assignments[i] != nearest) {
                        changes.increment();
                        assignments[i] = nearest;
                    }
                    totalSquaredDistance.add(nearestDist * nearestDist);
                }
            });

            // Update step: recompute the medoid of every non-empty cluster.
            for (int k = 0; k < medoids.length; k++) {
                members.clear();
                for (int i = 0; i < N; i++) {
                    if (assignments[i] == k) {
                        members.add(i);
                    }
                }
                if (!members.isEmpty()) {
                    medoids[k] = medoid(parallel, members, tol, X, this.dm, accel);
                }
            }
        } while (changes.sum() > 0 && iter++ < this.iterLimit);
        return totalSquaredDistance.sum();
    }

    /**
     * Finds the medoid of all the given vectors using a tolerance of
     * {@code 1 / X.size()}.
     *
     * @param parallel whether the work may be run in parallel
     * @param X the vectors to search
     * @param dm the distance metric to use
     * @return the index into {@code X} of the selected medoid
     */
    public static int medoid(boolean parallel, List<? extends Vec> X, DistanceMetric dm) {
        return medoid(parallel, X, 1.0d / X.size(), dm);
    }

    /**
     * Finds the medoid of all the given vectors, building the acceleration
     * cache from the metric.
     *
     * @param parallel whether the work may be run in parallel
     * @param X the vectors to search
     * @param tol the tolerance of the search
     * @param dm the distance metric to use
     * @return the index into {@code X} of the selected medoid
     */
    public static int medoid(boolean parallel, List<? extends Vec> X, double tol, DistanceMetric dm) {
        IntList all = new IntList(X.size());
        ListUtils.addRange(all, 0, X.size(), 1);
        return medoid(parallel, all, tol, X, dm, dm.getAccelerationCache(X, parallel));
    }

    /**
     * Finds the medoid among the given candidates using a tolerance of
     * {@code 1 / indices.size()}.
     *
     * @param parallel whether the work may be run in parallel
     * @param indices the indices into {@code X} of the candidate points
     * @param X the full list of vectors
     * @param dm the distance metric to use
     * @param accel the acceleration cache for {@code X}
     * @return the index into {@code X} of the selected medoid
     */
    public static int medoid(boolean parallel, Collection<Integer> indices, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        return medoid(parallel, indices, 1.0d / indices.size(), X, dm, accel);
    }

    /**
     * Finds the medoid among the given candidates by sampling pairwise
     * distances and maintaining confidence bounds on every candidate's mean
     * distance to the others.
     * <p>
     * The search delegates to the exact {@link PAM} implementation when
     * {@code tol <= 0}, when there are fewer candidates than logical cores, or
     * when there are fewer than two candidates (sampling a distinct partner
     * would otherwise never terminate).
     *
     * @param parallel whether the work may be run in parallel
     * @param indices the indices into {@code X} of the candidate points
     * @param tol the tolerance; the confidence radius scales with
     * {@code sqrt(log(1 / tol))}
     * @param X the full list of vectors
     * @param dm the distance metric to use
     * @param accel the acceleration cache for {@code X}
     * @return the index into {@code X} of the selected medoid
     */
    public static int medoid(boolean parallel, Collection<Integer> indices, double tol, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        final int N = indices.size();
        if (tol <= 0.0d || N < Math.max(2, SystemInfo.LogicalCores)) {
            return PAM.medoid(parallel, indices, X, dm, accel);
        }

        // log(1 / tol), written as a difference exactly like the bound formula expects
        final double logTerm = Math.log(1.0d) - Math.log(tol);
        // Running sum of sampled distances per candidate (or the exact mean once a candidate is exact)
        final AtomicDoubleArray distSum = new AtomicDoubleArray(N);
        // Number of distances accumulated per candidate
        final AtomicIntegerArray distCount = new AtomicIntegerArray(N);
        // Maps a candidate's position in [0, N) to its index in X
        final int[] indexMap = indices.stream().mapToInt(Integer::intValue).toArray();
        final boolean symmetric = dm.isSymmetric();
        final double[] lowerBound = new double[N];
        final double[] upperBound = new double[N];
        final ThreadLocal<Random> localRand = ThreadLocal.withInitial(RandomUtil::getRandom);

        // First pass: sample one distance per candidate so every candidate has an estimate.
        final OnLineStatistics distStats = ParallelUtils.run(parallel, N, (start, end) -> {
            Random rand = localRand.get();
            OnLineStatistics chunkStats = new OnLineStatistics();
            for (int i = start; i < end; i++) {
                int j = randomOther(rand, N, i);
                double d = dm.dist(indexMap[i], indexMap[j], X, accel);
                chunkStats.add(d);
                distSum.addAndGet(i, d);
                distCount.incrementAndGet(i);
                if (symmetric) {
                    // d(i, j) == d(j, i), so the sample is also valid for j
                    distSum.addAndGet(j, d);
                    distCount.incrementAndGet(j);
                }
            }
            return chunkStats;
        }, (a, b) -> OnLineStatistics.add(a, b));

        final ConcurrentSkipListSet<Integer> lowerQ = newBoundQueue(lowerBound);
        final ConcurrentSkipListSet<Integer> upperQ = newBoundQueue(upperBound);

        ParallelUtils.run(parallel, N, (start, end) -> {
            double v = distStats.getVarance();
            for (int i = start; i < end; i++) {
                int pulls = distCount.get(i);
                double radius = confidenceRadius(v, logTerm, pulls);
                lowerBound[i] = (distSum.get(i) / pulls) - radius;
                upperBound[i] = (distSum.get(i) / pulls) + radius;
                lowerQ.add(i);
                upperQ.add(i);
            }
        });

        // How many candidates are refined per round, and how many samples each one gets
        final int armsPerRound = parallel ? Math.max(SystemInfo.LogicalCores, BATCH_SIZE) : Math.min(BATCH_SIZE, N);
        final int samplesPerArm = Math.min(BATCH_SIZE, N - 1);

        // Candidates that receive more samples this round
        final IntList toPull = new IntList();
        // Candidates that were switched to their exact mean distance this round
        final IntList nowExact = new IntList();
        final boolean[] isExact = new boolean[N];
        int numExact = 0;

        while (numExact < N) {
            toPull.clear();
            nowExact.clear();

            // Converged: the best upper bound beats every lower bound.
            if (upperBound[upperQ.first()] < lowerBound[lowerQ.first()]) {
                return indexMap[upperQ.first()];
            }

            while (toPull.size() < armsPerRound && !lowerQ.isEmpty()) {
                final int arm = lowerQ.pollFirst();
                if (distCount.get(arm) >= N - 1 && !isExact[arm]) {
                    // Already sampled as often as an exact evaluation would cost: compute it exactly.
                    final Double exactSum = ParallelUtils.run(parallel, N, (start, end) -> {
                        double partial = 0.0d;
                        for (int j = start; j < end; j++) {
                            if (arm != j) {
                                partial += dm.dist(indexMap[arm], indexMap[j], X, accel);
                            }
                        }
                        return partial;
                    }, (a, b) -> a + b);
                    final double exactMean = exactSum / (N - 1);

                    // Remove from the upper queue before its sort key changes.
                    upperQ.remove(Integer.valueOf(arm));
                    upperBound[arm] = exactMean;
                    lowerBound[arm] = exactMean;
                    distSum.set(arm, exactMean);
                    distCount.set(arm, N);
                    isExact[arm] = true;
                    numExact++;
                    nowExact.add(arm);
                }
                if (!isExact[arm]) {
                    toPull.add(arm);
                }
            }

            // Draw additional samples for every selected candidate.
            final OnLineStatistics roundStats = ParallelUtils.run(parallel, toPull.size(), (start, end) -> {
                Random rand = localRand.get();
                OnLineStatistics chunkStats = new OnLineStatistics();
                for (int k = start; k < end; k++) {
                    int i = toPull.get(k);
                    for (int s = 0; s < samplesPerArm; s++) {
                        int j = randomOther(rand, N, i);
                        double d = dm.dist(indexMap[i], indexMap[j], X, accel);
                        chunkStats.add(d);
                        distSum.addAndGet(i, d);
                        distCount.incrementAndGet(i);
                        // Exact candidates hold a mean rather than a sum and must not be touched.
                        if (symmetric && !isExact[j]) {
                            distSum.addAndGet(j, d);
                            distCount.incrementAndGet(j);
                        }
                    }
                }
                return chunkStats;
            }, (a, b) -> OnLineStatistics.add(a, b));
            if (!toPull.isEmpty()) {
                distStats.add(roundStats);
            }

            final double variance = distStats.getVarance();
            lowerQ.addAll(nowExact);
            upperQ.addAll(nowExact);
            // Sampled candidates leave the upper queue before their bounds are rewritten.
            upperQ.removeAll(toPull);
            for (int i : toPull) {
                int pulls = distCount.get(i);
                double radius = confidenceRadius(variance, logTerm, pulls);
                lowerBound[i] = (distSum.get(i) / pulls) - radius;
                upperBound[i] = (distSum.get(i) / pulls) + radius;
                lowerQ.add(i);
                upperQ.add(i);
            }
        }

        // Every candidate was evaluated exactly: pick the smallest mean distance.
        int best = 0;
        for (int i = 1; i < N; i++) {
            if (lowerBound[i] < lowerBound[best]) {
                best = i;
            }
        }
        // Translate the candidate position back into an index of X, as the converged path does.
        return indexMap[best];
    }

    /** Draws a uniform value in {@code [0, bound)} different from {@code exclude}; requires {@code bound >= 2}. */
    private static int randomOther(Random rand, int bound, int exclude) {
        int j = rand.nextInt(bound);
        while (j == exclude) {
            j = rand.nextInt(bound);
        }
        return j;
    }

    /** Half-width of the confidence interval for a candidate sampled {@code pulls} times. */
    private static double confidenceRadius(double variance, double logTerm, int pulls) {
        return Math.sqrt(((2.0d * variance) * logTerm) / pulls);
    }

    /** Creates a set of candidate positions ordered by {@code bounds}, with ties broken by position. */
    private static ConcurrentSkipListSet<Integer> newBoundQueue(double[] bounds) {
        return new ConcurrentSkipListSet<Integer>((Integer a, Integer b) -> {
            int cmp = Double.compare(bounds[a], bounds[b]);
            if (cmp == 0) {
                // Equal bounds must not make distinct candidates look like duplicates.
                cmp = a.compareTo(b);
            }
            return cmp;
        });
    }
}
