package jsat.clustering;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.distributions.empirical.kernelfunc.GaussKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.multivariate.MetricKDE;
import jsat.distributions.multivariate.MultivariateKDE;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/clustering/MeanShift.class */
/**
 * Mean shift clustering.
 *
 * <p>Every data point starts at its own position and is repeatedly replaced by the weighted mean of
 * the neighbours returned by the backing {@link MultivariateKDE}. Each neighbour is weighted by the
 * negated kernel derivative. A point stops moving once a step is shorter than {@code 1e-5}, or
 * after {@link #getMaxIterations()} passes. Afterwards the lowest-index unassigned point seeds a
 * new cluster, and every other unassigned point within {@code 1e-3} of it joins that cluster.
 * Points whose neighbourhood query returned at most one entry are labelled {@code -1}.
 *
 * <p>{@link #cluster} refits the shared KDE, so one instance should not run {@code cluster} from
 * several threads at the same time.
 */
public class MeanShift implements Clusterer {
    private static final long serialVersionUID = 4061491342362690455L;

    /** Default cap on the number of mean-shift passes over the data. */
    public static final int DefaultMaxIterations = 1000;
    /** Default multiplier handed to {@link MultivariateKDE#scaleBandwidth} before clustering. */
    public static final double DefaultScaleBandwidthFactor = 1.0d;

    /** A shift shorter than this marks the point as converged. */
    private static final double CONVERGENCE_TOLERANCE = 1e-5;
    /** Converged points this close to a cluster's seed are merged into that cluster. */
    private static final double MERGE_TOLERANCE = 1e-3;
    /** Designation given to points with no usable neighbours. */
    private static final int NOISE = -1;

    private final MultivariateKDE mkde;
    private int maxIterations = DefaultMaxIterations;
    private double scaleBandwidthFactor = DefaultScaleBandwidthFactor;

    /** Creates a mean shift clusterer using the {@link GaussKF} kernel and Euclidean distance. */
    public MeanShift() {
        this(new EuclideanDistance());
    }

    /**
     * Creates a mean shift clusterer using the {@link GaussKF} kernel over the given metric.
     *
     * @param distanceMetric the distance metric handed to the backing {@link MetricKDE}
     */
    public MeanShift(DistanceMetric distanceMetric) {
        this(new MetricKDE(GaussKF.getInstance(), distanceMetric));
    }

    /**
     * Creates a mean shift clusterer backed by the given density estimator.
     *
     * @param multivariateKDE supplies the kernel and the neighbourhood queries; it is refit on
     *     every call to {@link #cluster}
     */
    public MeanShift(MultivariateKDE multivariateKDE) {
        this.mkde = multivariateKDE;
    }

    /**
     * Copy constructor. The density estimator is cloned, so the copy does not share fitting state
     * with {@code meanShift}.
     *
     * @param meanShift the instance to copy
     */
    public MeanShift(MeanShift meanShift) {
        this.mkde = meanShift.mkde.clone();
        this.maxIterations = meanShift.maxIterations;
        this.scaleBandwidthFactor = meanShift.scaleBandwidthFactor;
    }

    /**
     * Sets the maximum number of mean-shift passes made over the data.
     *
     * @param i the iteration cap; must be positive
     * @throws ArithmeticException if {@code i <= 0}
     */
    public void setMaxIterations(int i) {
        if (i <= 0) {
            throw new ArithmeticException("Invalid iteration count, " + i);
        }
        this.maxIterations = i;
    }

    /** @return the maximum number of mean-shift passes made over the data */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the factor by which the KDE's bandwidth is scaled after it is fit to the data.
     *
     * @param d the scale factor; must be finite and strictly positive
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setScaleBandwidthFactor(double d) {
        // A zero or negative factor produces a meaningless (zero/negative) kernel bandwidth.
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0) {
            throw new ArithmeticException("Invalid scale factor, " + d);
        }
        this.scaleBandwidthFactor = d;
    }

    /** @return the factor by which the KDE's bandwidth is scaled before clustering */
    public double getScaleBandwidthFactor() {
        return this.scaleBandwidthFactor;
    }

    /**
     * Clusters the data set with mean shift.
     *
     * @param dataSet the points to cluster
     * @param parallel whether the KDE fit and the shifting passes may use multiple threads
     * @param designations an array to store the result in. It is reused when non-null and at
     *     least {@code dataSet.getSampleSize()} long; otherwise a new array is allocated.
     * @return the designations array. Entry {@code i} is the cluster id of point {@code i}, or
     *     {@code -1} for a point with no usable neighbours.
     * @throws FailedToFitException if clustering is interrupted or a barrier is broken
     */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        final int n = dataSet.getSampleSize();
        if (designations == null || designations.length < n) {
            designations = new int[n];
        } else {
            // A reused array may still hold labels (including -1) from an earlier call; the
            // shifting phase only ever writes -1, so stale noise labels would otherwise stick.
            Arrays.fill(designations, 0, n, 0);
        }

        try {
            boolean[] converged = new boolean[n];
            KernelFunction k = mkde.getKernelFunction();
            mkde.setUsingData(dataSet, parallel);
            mkde.scaleBandwidth(scaleBandwidthFactor);

            // Work on copies so shifting never mutates the caller's data points.
            Vec[] xit = new Vec[n];
            for (int i = 0; i < n; i++) {
                xit[i] = dataSet.getDataPoint(i).getNumericalValues().mo46clone();
            }

            mainLoop(converged, xit, designations, k, parallel);
            assignmentStep(converged, xit, designations);
            return designations;
        } catch (InterruptedException ex) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw failedToFit(ex);
        } catch (BrokenBarrierException ex) {
            throw failedToFit(ex);
        }
    }

    /** Logs {@code cause} and wraps it in the exception type {@link #cluster} reports failures with. */
    private static FailedToFitException failedToFit(Exception cause) {
        Logger.getLogger(MeanShift.class.getName()).log(Level.SEVERE, null, cause);
        return new FailedToFitException(cause);
    }

    /**
     * Groups the shifted points into clusters.
     *
     * <p>{@code unassigned} is reused as an "unassigned" flag array (every entry is {@code true}
     * on entry). The lowest-index unassigned, non-noise point seeds a new cluster. Every later
     * unassigned, non-noise point within {@code 1e-3} of the seed joins it. Noise points keep
     * their {@code -1} label.
     */
    private void assignmentStep(boolean[] unassigned, Vec[] xit, int[] designations) {
        int clusterID = 0;
        for (int seed = 0; seed < unassigned.length; seed++) {
            if (!unassigned[seed] || designations[seed] == NOISE) {
                continue;
            }
            // Assign the seed unconditionally so every outer step makes progress, even if its
            // position is NaN; skipping noise seeds avoids the original endless loop.
            unassigned[seed] = false;
            designations[seed] = clusterID;
            for (int i = seed + 1; i < unassigned.length; i++) {
                if (unassigned[i] && designations[i] != NOISE
                        && xit[seed].pNormDist(2.0d, xit[i]) < MERGE_TOLERANCE) {
                    unassigned[i] = false;
                    designations[i] = clusterID;
                }
            }
            clusterID++;
        }
    }

    /**
     * Repeatedly shifts every unconverged point until none moved in a pass or the iteration cap
     * is hit, then sets every {@code converged} flag to {@code true}.
     */
    private void mainLoop(boolean[] converged, Vec[] xit, int[] designations, KernelFunction k, boolean parallel)
            throws InterruptedException, BrokenBarrierException {
        AtomicBoolean progress = new AtomicBoolean(true);
        // Per-thread scratch vector. It is created lazily, so xit[0] is only read once at least
        // one point needs a step (i.e. the data set is non-empty).
        ThreadLocal<Vec> scratch = ThreadLocal.withInitial(() -> new DenseVector(xit[0].length()));

        for (int iter = 0; iter < maxIterations && progress.get(); iter++) {
            progress.set(false);
            ParallelUtils.run(parallel, converged.length, i -> {
                if (converged[i]) {
                    return;
                }
                // Several workers may flag progress; the value is read only after run() returns.
                progress.lazySet(true);
                convergenceStep(xit, i, converged, designations, scratch.get(), k);
            });
        }
        Arrays.fill(converged, true);
    }

    /**
     * Performs one mean-shift step for point {@code i}. The point's vector in {@code xit} is
     * overwritten with the weighted mean of its neighbours.
     *
     * @param scratch a working vector with the same length as the data points; clobbered
     */
    private void convergenceStep(Vec[] xit, int i, boolean[] converged, int[] designations, Vec scratch,
            KernelFunction k) {
        Vec xi = xit[i];
        List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> neighbors = mkde.getNearbyRaw(xi);
        // With at most one entry (presumably only the point itself) there is nothing to shift
        // towards, so the point is marked as noise.
        if (neighbors.size() <= 1) {
            converged[i] = true;
            designations[i] = NOISE;
            return;
        }

        scratch.zeroOut();
        double weightSum = 0.0d;
        for (VecPaired<VecPaired<Vec, Integer>, Double> neighbor : neighbors) {
            // NOTE(review): the paired Double is presumably the neighbour's (scaled) distance.
            double w = -k.kPrime(neighbor.getPair());
            weightSum += w;
            scratch.mutableAdd(w, neighbor);
        }

        // All-zero weights leave the mean undefined (0/0 -> NaN); keep the point where it is.
        if (weightSum == 0.0d) {
            converged[i] = true;
            return;
        }

        scratch.mutableDivide(weightSum);
        if (scratch.pNormDist(2.0d, xi) < CONVERGENCE_TOLERANCE) {
            converged[i] = true;
        }
        scratch.copyTo(xi);
    }

    @Override // jsat.clustering.Clusterer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public MeanShift mo114clone() {
        return new MeanShift(this);
    }
}
