package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.DoubleStream;
import jsat.DataSet;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/TRIKMEDS.class */
public class TRIKMEDS extends PAM {
    /**
     * Creates a TRIKMEDS clusterer by forwarding all three arguments unchanged to the
     * matching {@code PAM} constructor.
     *
     * @param distanceMetric the distance metric handed to {@code PAM}
     * @param random the random source handed to {@code PAM}
     * @param seedSelection the seed selection method handed to {@code PAM}
     */
    public TRIKMEDS(DistanceMetric distanceMetric, Random random, SeedSelectionMethods.SeedSelection seedSelection) {
        super(distanceMetric, random, seedSelection);
    }

    /**
     * Creates a TRIKMEDS clusterer by forwarding both arguments unchanged to the matching
     * {@code PAM} constructor; the seed selection method is whatever that constructor picks.
     *
     * @param distanceMetric the distance metric handed to {@code PAM}
     * @param random the random source handed to {@code PAM}
     */
    public TRIKMEDS(DistanceMetric distanceMetric, Random random) {
        super(distanceMetric, random);
    }

    /**
     * Creates a TRIKMEDS clusterer by forwarding the metric unchanged to the matching
     * {@code PAM} constructor; every other setting is left to that constructor.
     *
     * @param distanceMetric the distance metric handed to {@code PAM}
     */
    public TRIKMEDS(DistanceMetric distanceMetric) {
        super(distanceMetric);
    }

    /**
     * Creates a TRIKMEDS clusterer with whatever settings the no-argument {@code PAM}
     * constructor establishes.
     */
    public TRIKMEDS() {
        super();
    }

    /**
     * Installs the distance metric after checking that it reports itself as a valid metric.
     *
     * @param distanceMetric the metric to use; it is only passed on to {@code PAM} when
     *        {@code isValidMetric()} returns {@code true}
     * @throws IllegalArgumentException if {@code distanceMetric.isValidMetric()} is {@code false}
     */
    @Override // jsat.clustering.PAM
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        if (distanceMetric.isValidMetric()) {
            super.setDistanceMetric(distanceMetric);
            return;
        }
        final String metricDescription = distanceMetric.toString();
        throw new IllegalArgumentException("TRIKMEDS requires a valid distance metric, but " + metricDescription + " does not obey all distance metric properties");
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Assigns every data point to its closest medoid and then repeatedly re-selects medoids
     * and re-assigns points, stopping when an iteration moves no point to a different
     * cluster or when the iteration counter reaches {@code iterLimit}.
     *
     * <p>NOTE(review): this is decompiler output. Several lambdas below declare a loop
     * variable with the same name as one of the lambda's own parameters ({@code i4}/{@code i5},
     * {@code i8}, {@code i10}, {@code i12}/{@code i13}). The Java compiler rejects such
     * redeclarations, so the intended index of each use has to be restored before this
     * method compiles. Arrays sized by {@code length} can only be indexed by the cluster
     * loop variable; for arrays sized by {@code sampleSize} the intended index cannot be
     * proven from this file alone.
     *
     * @param dataSet supplies the data vectors and the number of samples
     * @param z when {@code true}, the metric is trained if needed, a new acceleration cache is
     *        built and the initial medoids are selected into {@code iArr}; when {@code false},
     *        {@code iArr} is used as given and {@code list} is used as the acceleration cache
     * @param iArr medoid index of each cluster; its length is the number of clusters and its
     *        entries are overwritten whenever a medoid changes
     * @param iArr2 cluster index of each data point; reset to -1 and then filled in
     * @param list acceleration cache used when {@code z} is {@code false}
     * @param z2 forwarded to {@code ParallelUtils.run} and {@code ParallelUtils.streamP};
     *        presumably enables parallel execution (confirm against ParallelUtils)
     * @return the sum of the squared distances between each point and the medoid of the
     *         cluster it ended up in
     */
    @Override // jsat.clustering.PAM
    public double cluster(DataSet dataSet, boolean z, int[] iArr, int[] iArr2, List<Double> list, boolean z2) {
        List<Double> list2;
        int i;
        // Counts how many points switched cluster during the current iteration.
        LongAdder longAdder = new LongAdder();
        // Reset every assignment to -1 before the first pass overwrites it.
        Arrays.fill(iArr2, -1);
        List<Vec> dataVectors = dataSet.getDataVectors();
        if (z) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, dataSet);
            list2 = this.dm.getAccelerationCache(dataVectors);
            SeedSelectionMethods.selectIntialPoints(dataSet, iArr, this.dm, list2, this.rand, this.seedSelection);
        } else {
            list2 = list;
        }
        int sampleSize = dataSet.getSampleSize();
        // Number of clusters.
        int length = iArr.length;
        // Working copy of the medoid index of each cluster.
        AtomicIntegerArray atomicIntegerArray = new AtomicIntegerArray(length);
        // Distance from each point to the medoid of the cluster it is assigned to.
        double[] dArr = new double[sampleSize];
        // Scratch distances computed while a point is evaluated as a replacement medoid.
        double[] dArr2 = new double[sampleSize];
        // Number of points currently assigned to each cluster.
        AtomicDoubleArray atomicDoubleArray = new AtomicDoubleArray(length);
        // Per-point, per-cluster distance value: exact after the first pass, later reduced by
        // the distance each medoid moved and recomputed only when it falls below the point's
        // current distance.
        double[][] dArr3 = new double[sampleSize][length];
        // Per-point value that is compared with a cluster's distance sum to decide whether the
        // point is evaluated as a medoid; presumably a lower bound on that sum - confirm.
        AtomicDoubleArray atomicDoubleArray2 = new AtomicDoubleArray(sampleSize);
        // Distance between the previous and the new medoid of each cluster.
        double[] dArr4 = new double[length];
        // Per-cluster sum of distances from the assigned points to the medoid.
        AtomicDoubleArray atomicDoubleArray3 = new AtomicDoubleArray(length);
        // Per-cluster set of the indices of the points assigned to it.
        ArrayList arrayList = new ArrayList(length);
        for (int i2 = 0; i2 < length; i2++) {
            arrayList.add(new ConcurrentSkipListSet());
        }
        // Per-iteration bookkeeping, reset at the start of every re-assignment step:
        // atomicDoubleArray4 counts points that joined a cluster, atomicDoubleArray5 counts
        // points that left it, atomicDoubleArray6 accumulates distances of moved points.
        // NOTE(review): atomicDoubleArray7 is only ever reset to 0 in this method, while both
        // the new and the old distance of a moved point are added to atomicDoubleArray6; the
        // second add may have been meant for atomicDoubleArray7 - confirm against upstream.
        AtomicDoubleArray atomicDoubleArray4 = new AtomicDoubleArray(length);
        AtomicDoubleArray atomicDoubleArray5 = new AtomicDoubleArray(length);
        AtomicDoubleArray atomicDoubleArray6 = new AtomicDoubleArray(length);
        AtomicDoubleArray atomicDoubleArray7 = new AtomicDoubleArray(length);
        for (int i3 = 0; i3 < length; i3++) {
            atomicIntegerArray.set(i3, iArr[i3]);
        }
        // Alias of list2 captured by the lambda below.
        List<Double> list3 = list2;
        // Initial assignment: compute the distance from every point to every medoid and keep
        // the closest one.
        // NOTE(review): i4 and i5 are each declared twice (lambda parameter and local).
        ParallelUtils.run(z2, sampleSize, (i4, i5) -> {
            for (int i4 = i4; i4 < i5; i4++) {
                double d = Double.POSITIVE_INFINITY;
                int i5 = 0;
                for (int i6 = 0; i6 < length; i6++) {
                    dArr3[i4][i6] = this.dm.dist(i4, atomicIntegerArray.get(i6), (List<? extends Vec>) dataVectors, (List<Double>) list3);
                    // "<=" means that on an exact tie the later cluster wins.
                    if (dArr3[i4][i6] <= d) {
                        d = dArr3[i4][i6];
                        i5 = i6;
                    }
                }
                iArr2[i4] = i5;
                dArr[i4] = d;
                atomicDoubleArray.getAndAdd(iArr2[i4], 1.0d);
                ((Set) arrayList.get(i5)).add(Integer.valueOf(i4));
                atomicDoubleArray3.addAndGet(iArr2[i4], dArr[i4]);
                atomicDoubleArray2.set(i4, 0.0d);
            }
        });
        // For the point that currently is a cluster's medoid the distance sum is already
        // known, so store it in that point's slot.
        for (int i6 = 0; i6 < length; i6++) {
            atomicDoubleArray2.set(atomicIntegerArray.get(i6), atomicDoubleArray3.get(i6));
        }
        // Iteration counter.
        int i7 = 0;
        do {
            longAdder.reset();
            // Set to true for every cluster whose medoid is replaced in this iteration.
            boolean[] zArr = new boolean[length];
            Arrays.fill(zArr, false);
            List<Double> list4 = list2;
            // Medoid update: a point whose stored value is below the cluster's distance sum
            // gets its real distance sum computed and replaces the medoid when it is smaller.
            // NOTE(review): i8 is both the lambda parameter and the loop variable, so every
            // use of i8 below is ambiguous between the two as written.
            ParallelUtils.run(z2, sampleSize, i8 -> {
                for (int i8 = 0; i8 < length; i8++) {
                    if (atomicDoubleArray2.get(i8) < atomicDoubleArray3.get(i8)) {
                        double d = 0.0d;
                        Iterator it = ((Set) arrayList.get(i8)).iterator();
                        while (it.hasNext()) {
                            int intValue = ((Integer) it.next()).intValue();
                            dArr2[intValue] = this.dm.dist(i8, intValue, (List<? extends Vec>) dataVectors, (List<Double>) list4);
                            d += dArr2[intValue];
                        }
                        atomicDoubleArray2.set(i8, d);
                        // Cheap unsynchronized test first, repeated under the lock before the
                        // shared medoid state is modified.
                        if (d < atomicDoubleArray3.get(i8)) {
                            synchronized (atomicDoubleArray3) {
                                if (d < atomicDoubleArray3.get(i8)) {
                                    atomicDoubleArray3.set(i8, d);
                                    atomicIntegerArray.set(i8, i8);
                                    zArr[i8] = true;
                                    // The scratch distances become the members' distances to
                                    // their medoid.
                                    Iterator it2 = ((Set) arrayList.get(i8)).iterator();
                                    while (it2.hasNext()) {
                                        int intValue2 = ((Integer) it2.next()).intValue();
                                        dArr[intValue2] = dArr2[intValue2];
                                    }
                                }
                            }
                        }
                        // Raise the stored value of every member to at least
                        // |dArr[member] * memberCount - current value|.
                        Iterator it3 = ((Set) arrayList.get(i8)).iterator();
                        while (it3.hasNext()) {
                            int intValue3 = ((Integer) it3.next()).intValue();
                            atomicDoubleArray2.accumulateAndGet(intValue3, dArr[intValue3] * atomicDoubleArray.get(i8), (d2, d3) -> {
                                return Math.max(d2, Math.abs(d3 - d2));
                            });
                        }
                    }
                }
            });
            List<Double> list5 = list2;
            // Publish changed medoids to iArr, record how far each one moved, and clear the
            // per-iteration bookkeeping arrays.
            ParallelUtils.run(z2, length, i9 -> {
                if (zArr[i9]) {
                    dArr4[i9] = this.dm.dist(iArr[i9], atomicIntegerArray.get(i9), (List<? extends Vec>) dataVectors, (List<Double>) list5);
                    iArr[i9] = atomicIntegerArray.get(i9);
                }
                atomicDoubleArray4.set(i9, 0.0d);
                atomicDoubleArray5.set(i9, 0.0d);
                atomicDoubleArray6.set(i9, 0.0d);
                atomicDoubleArray7.set(i9, 0.0d);
            });
            List<Double> list6 = list2;
            // Re-assignment: loosen the stored per-cluster distances by the medoid movement and
            // recompute a real distance only where the loosened value could beat the current one.
            // NOTE(review): i10 is both the lambda parameter and the first loop's variable.
            ParallelUtils.run(z2, sampleSize, i10 -> {
                for (int i10 = 0; i10 < length; i10++) {
                    double[] dArr5 = dArr3[i10];
                    int i11 = i10;
                    dArr5[i11] = dArr5[i11] - dArr4[i10];
                }
                // The distance to the currently assigned medoid is known exactly.
                dArr3[i10][iArr2[i10]] = dArr[i10];
                // Remember the previous cluster and distance for the bookkeeping below.
                int i12 = iArr2[i10];
                double d = dArr[i10];
                for (int i13 = 0; i13 < length; i13++) {
                    if (dArr3[i10][i13] < dArr[i10]) {
                        dArr3[i10][i13] = this.dm.dist(i10, iArr[i13], (List<? extends Vec>) dataVectors, (List<Double>) list6);
                        if (dArr3[i10][i13] < dArr[i10]) {
                            iArr2[i10] = i13;
                            dArr[i10] = dArr3[i10][i13];
                        }
                    }
                }
                if (i12 != iArr2[i10]) {
                    // The point moved: update member counts, member sets and bookkeeping.
                    atomicDoubleArray.getAndDecrement(i12);
                    atomicDoubleArray.getAndIncrement(iArr2[i10]);
                    longAdder.increment();
                    ((Set) arrayList.get(i12)).remove(Integer.valueOf(i10));
                    ((Set) arrayList.get(iArr2[i10])).add(Integer.valueOf(i10));
                    atomicDoubleArray2.set(i10, 0.0d);
                    atomicDoubleArray4.getAndIncrement(iArr2[i10]);
                    atomicDoubleArray5.getAndIncrement(i12);
                    atomicDoubleArray6.getAndAdd(iArr2[i10], dArr[i10]);
                    atomicDoubleArray6.getAndAdd(i12, d);
                }
            });
            // Per-cluster sums and differences of the bookkeeping arrays, precomputed for the
            // adjustment pass below.
            double[] dArr5 = new double[length];
            double[] dArr6 = new double[length];
            double[] dArr7 = new double[length];
            double[] dArr8 = new double[length];
            for (int i11 = 0; i11 < length; i11++) {
                dArr5[i11] = atomicDoubleArray6.get(i11) + atomicDoubleArray7.get(i11);
                dArr6[i11] = atomicDoubleArray6.get(i11) - atomicDoubleArray7.get(i11);
                dArr7[i11] = atomicDoubleArray4.get(i11) + atomicDoubleArray5.get(i11);
                dArr8[i11] = atomicDoubleArray4.get(i11) - atomicDoubleArray5.get(i11);
            }
            // Adjust every point's stored value by the accumulated effect of this iteration's
            // membership changes.
            // NOTE(review): i12 and i13 are each declared twice (lambda parameter and local).
            ParallelUtils.run(z2, sampleSize, (i12, i13) -> {
                for (int i12 = i12; i12 < i13; i12++) {
                    double d = 0.0d;
                    for (int i13 = 0; i13 < length; i13++) {
                        d -= Math.min(dArr5[i13] - (dArr8[i13] * dArr[i12]), (dArr7[i13] * dArr[i12]) - dArr6[i13]);
                    }
                    atomicDoubleArray2.getAndAdd(i12, d);
                }
            });
            // Converged: no point changed cluster in this iteration.
            if (longAdder.sum() <= 0) {
                break;
            }
            // The counter value from before the increment is the one compared with the limit.
            i = i7;
            i7++;
        } while (i < this.iterLimit);
        // Sum of squared point-to-medoid distances.
        return ParallelUtils.streamP(DoubleStream.of(dArr), z2).map(d -> {
            return d * d;
        }).sum();
    }

    /**
     * Convenience overload that searches for the medoid among the points of {@code list}:
     * it builds an index list via {@code ListUtils.addRange(indices, 0, list.size(), 1)},
     * obtains an acceleration cache from the metric and delegates to
     * {@link #medoid(boolean, Collection, List, DistanceMetric, List)}.
     *
     * @param z forwarded both to {@code getAccelerationCache} and to the delegate; presumably
     *        enables parallel execution (confirm against the callee)
     * @param list the data points
     * @param distanceMetric the metric used to build the cache and measure distances
     * @return whatever the delegate returns: an index into {@code list}, or -1
     */
    public static int medoid(boolean z, List<? extends Vec> list, DistanceMetric distanceMetric) {
        final int pointCount = list.size();
        final IntList everyIndex = new IntList(pointCount);
        ListUtils.addRange(everyIndex, 0, pointCount, 1);
        final List<Double> accelerationCache = distanceMetric.getAccelerationCache(list, z);
        return medoid(z, everyIndex, list, distanceMetric, accelerationCache);
    }

    /**
     * Searches the points selected by {@code collection} for the one whose mean distance to
     * the other selected points is smallest.
     *
     * <p>Candidates are visited in shuffled order. Each candidate carries a bound, initially
     * 0; a candidate whose bound is not below the best mean found so far is skipped without
     * computing any distance. When a candidate is evaluated, its bound becomes its mean
     * distance, and the bound of every selected point {@code j} is raised to at least
     * {@code |mean - dist(candidate, j)|}.
     *
     * <p>A collection with exactly one index returns that index directly. The general path
     * divides by {@code collection.size() - 1}, which for a single point is 0/0 = NaN and
     * previously made the method return -1.
     *
     * @param z forwarded to {@code ParallelUtils.streamP}; presumably enables parallel
     *        evaluation of the candidates (confirm against ParallelUtils)
     * @param collection indices into {@code list} of the points to consider
     * @param list the data points
     * @param distanceMetric the metric used to measure distances
     * @param list2 acceleration cache passed through to {@code distanceMetric.dist}
     * @return the index of the selected point with the smallest mean distance, or -1 when
     *         {@code collection} is empty or no candidate's bound equals the best mean
     */
    public static int medoid(boolean z, Collection<Integer> collection, List<? extends Vec> list, DistanceMetric distanceMetric, List<Double> list2) {
        final int candidateCount = collection.size();
        if (candidateCount == 0) {
            return -1;
        }
        if (candidateCount == 1) {
            // A lone point is trivially its own medoid; the mean below would be 0/0.
            return collection.iterator().next();
        }
        // Every mean is taken over the other candidateCount - 1 points.
        final int otherPoints = candidateCount - 1;
        final int size = list.size();
        final AtomicDoubleArray bound = new AtomicDoubleArray(size);
        final AtomicDouble bestMean = new AtomicDouble(Double.POSITIVE_INFINITY);
        final IntList visitOrder = new IntList(collection);
        Collections.shuffle(visitOrder, RandomUtil.getRandom());
        // One scratch array per worker thread so concurrent candidates do not overwrite each
        // other's distances.
        final ThreadLocal<double[]> scratch = ThreadLocal.withInitial(() -> new double[size]);
        ParallelUtils.streamP(visitOrder.streamInts(), z).forEach(i -> {
            // Written as a negated "<" so a NaN bound is skipped exactly as before.
            if (!(bound.get(i) < bestMean.get())) {
                return;
            }
            final double[] distances = scratch.get();
            double total = 0.0d;
            for (int j : collection) {
                final double dist = distanceMetric.dist(i, j, (List<? extends Vec>) list, (List<Double>) list2);
                distances[j] = dist;
                total += dist;
            }
            final double mean = total / otherPoints;
            bound.set(i, mean);
            if (mean < bestMean.get()) {
                bestMean.getAndUpdate(current -> Math.min(current, mean));
            }
            for (int j : collection) {
                bound.getAndUpdate(j, current -> Math.max(current, Math.abs(mean - distances[j])));
            }
        });
        // The winner is the candidate whose stored bound equals the best mean exactly.
        final double best = bestMean.get();
        for (int candidate : collection) {
            if (bound.get(candidate) == best) {
                return candidate;
            }
        }
        return -1;
    }
}
