package jsat.clustering;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.distributions.multivariate.MultivariateDistribution;
import jsat.distributions.multivariate.NormalM;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/EMGaussianMixture.class */
/**
 * A mixture of multivariate Gaussians ({@link NormalM}) fit with the
 * Expectation-Maximization (EM) algorithm.
 * <p>
 * The fitted model can be used in two ways:
 * <ul>
 * <li>As a {@link KClusterer}: each point is hard-assigned to its most probable
 * component.</li>
 * <li>As a {@link MultivariateDistribution}: through {@link #pdf(Vec)},
 * {@link #logPdf(Vec)} and {@link #sample(int, Random)}.</li>
 * </ul>
 * Only the numerical features of each {@link DataPoint} are used.
 * <p>
 * The number of components must be given explicitly. The overloads that would
 * have to choose it throw {@link UnsupportedOperationException}.
 */
public class EMGaussianMixture implements KClusterer, MultivariateDistribution {

    private static final long serialVersionUID = 2606159815670221662L;

    /** Default convergence threshold on the change in log-likelihood. */
    private static final double DEFAULT_TOLERANCE = 0.001d;

    /** Strategy used to choose the initial component means. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** Mixture components; {@code null} until a model has been fit. */
    private List<NormalM> gaussians;
    /** Mixing weight of each component; after fitting the weights sum to one. */
    private double[] a_k;
    /** EM stops once the log-likelihood changes by less than this amount. */
    private double tolerance;
    /** Maximum number of EM iterations (one M-step followed by one E-step each). */
    protected int MaxIterLimit;

    /**
     * Creates a new, unfitted mixture model.
     *
     * @param seedSelection the method used to choose the initial component means
     */
    public EMGaussianMixture(SeedSelectionMethods.SeedSelection seedSelection) {
        this.tolerance = DEFAULT_TOLERANCE;
        this.MaxIterLimit = Integer.MAX_VALUE;
        setSeedSelection(seedSelection);
    }

    /**
     * Creates a new, unfitted mixture model that seeds its means with
     * {@link SeedSelectionMethods.SeedSelection#KPP}.
     */
    public EMGaussianMixture() {
        this(SeedSelectionMethods.SeedSelection.KPP);
    }

    /**
     * Copy constructor. It deep-copies the fitted components and weights, and
     * copies the seed selection, tolerance and iteration limit.
     *
     * @param toCopy the model to copy
     */
    public EMGaussianMixture(EMGaussianMixture toCopy) {
        // seedSelection must be copied too, otherwise a clone cannot be refit.
        this.seedSelection = toCopy.seedSelection;
        this.tolerance = toCopy.tolerance;
        this.MaxIterLimit = toCopy.MaxIterLimit;
        if (toCopy.gaussians != null && !toCopy.gaussians.isEmpty()) {
            this.gaussians = new ArrayList<>(toCopy.gaussians.size());
            for (NormalM component : toCopy.gaussians) {
                this.gaussians.add(component.clone());
            }
        }
        if (toCopy.a_k != null) {
            this.a_k = Arrays.copyOf(toCopy.a_k, toCopy.a_k.length);
        }
    }

    /**
     * Builds an already-fitted mixture from explicit components and weights.
     *
     * @param gaussians the components; the first {@code a_k.length} are cloned
     * @param a_k       the mixing weights (copied)
     * @param tolerance the convergence tolerance to use if the model is refit
     */
    private EMGaussianMixture(List<NormalM> gaussians, double[] a_k, double tolerance) {
        this.seedSelection = SeedSelectionMethods.SeedSelection.KPP;
        this.tolerance = tolerance;
        this.MaxIterLimit = Integer.MAX_VALUE;
        this.gaussians = new ArrayList<>(a_k.length);
        this.a_k = Arrays.copyOf(a_k, a_k.length);
        for (int k = 0; k < a_k.length; k++) {
            this.gaussians.add(gaussians.get(k).clone());
        }
    }

    /**
     * Sets the method used to choose the initial component means.
     *
     * @param seedSelection the seed selection method
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * Returns the method used to choose the initial component means.
     *
     * @return the seed selection method
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * Sets the maximum number of EM iterations. Each iteration is one M-step
     * followed by one E-step.
     *
     * @param i the iteration limit, at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setIterationLimit(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + i);
        }
        this.MaxIterLimit = i;
    }

    /**
     * Returns the maximum number of EM iterations.
     *
     * @return the iteration limit
     */
    public int getIterationLimit() {
        return this.MaxIterLimit;
    }

    /**
     * Initializes the mixture from a nearest-mean (k-means style) hard assignment,
     * then runs EM.
     * <p>
     * If {@code means} holds fewer than {@code K} vectors, it is cleared and
     * refilled with seeds chosen by {@link #getSeedSelection()}. Only the first
     * {@code K} means are used. The means are updated in place by EM.
     *
     * @param dataSet     the data to fit
     * @param accelCache  optional Euclidean acceleration cache, may be {@code null}
     * @param K           the number of mixture components
     * @param means       initial means; receives the fitted means
     * @param assignment  receives the hard cluster assignment of every point
     * @param exactTotal  unused by this implementation
     * @param parallel    whether EM steps may run in parallel
     * @param returnError unused by this implementation
     * @return the negated log-likelihood of the final model
     */
    protected double cluster(DataSet dataSet, List<Double> accelCache, int K, List<Vec> means, int[] assignment,
            boolean exactTotal, boolean parallel, boolean returnError) {
        EuclideanDistance dm = new EuclideanDistance();
        if (means.size() < K) {
            means.clear();
            means.addAll(SeedSelectionMethods.selectIntialPoints(dataSet, K, dm, accelCache, RandomUtil.getRandom(),
                    seedSelection, parallel));
        }
        // Query info is needed for every mean used, including means supplied by the caller.
        List<List<Double>> meansQueryInfo = new ArrayList<>(K);
        for (int k = 0; k < K; k++) {
            meansQueryInfo.add(dm.getQueryInfo(means.get(k)));
        }

        int D = dataSet.getNumNumericalVars();
        List<Matrix> covs = new ArrayList<>(K);
        for (int k = 0; k < K; k++) {
            covs.add(new DenseMatrix(D, D));
        }

        int N = dataSet.getSampleSize();
        this.a_k = new double[K];
        DenseVector scratch = new DenseVector(D);
        List<Vec> dataVectors = dataSet.getDataVectors();
        for (int i = 0; i < N; i++) {
            int nearest = 0;
            double nearestDist = dm.dist(i, means.get(0), meansQueryInfo.get(0), dataVectors, accelCache);
            for (int k = 1; k < K; k++) {
                double dist = dm.dist(i, means.get(k), meansQueryInfo.get(k), dataVectors, accelCache);
                if (dist < nearestDist) {
                    nearestDist = dist;
                    nearest = k;
                }
            }
            assignment[i] = nearest;
            this.a_k[nearest] += 1.0d;
            // Accumulate (x_i - mu)(x_i - mu)^T into the nearest component's scatter matrix.
            dataSet.getDataPoint(i).getNumericalValues().copyTo(scratch);
            scratch.mutableSubtract(means.get(nearest));
            Matrix.OuterProductUpdate(covs.get(nearest), scratch, scratch, 1.0d);
        }
        // Scatter -> covariance (divide by member count), then counts -> mixing weights.
        for (int k = 0; k < K; k++) {
            covs.get(k).mutableMultiply(1.0d / this.a_k[k]);
            this.a_k[k] /= N;
        }
        return clusterCompute(K, dataSet, assignment, means, covs, parallel);
    }

    /**
     * Runs EM from the given initial means, covariances and weights
     * ({@link #a_k}). It stops when the log-likelihood changes by less than the
     * tolerance, or after {@link #MaxIterLimit} iterations. Afterwards every point
     * is hard-assigned to its most probable component.
     *
     * @param K          the number of mixture components
     * @param dataSet    the data to fit
     * @param assignment holds an initial assignment on entry; receives the final
     *                   hard assignment
     * @param means      initial means; updated in place
     * @param covs       initial covariances; updated in place
     * @param parallel   whether EM steps may run in parallel
     * @return the negated log-likelihood of the final model
     * @throws ClusterFailureException if an EM step is interrupted or fails
     *                                 during parallel execution
     * @throws ArithmeticException     if the log-likelihood becomes NaN
     */
    protected double clusterCompute(int K, DataSet dataSet, int[] assignment, List<Vec> means, List<Matrix> covs,
            boolean parallel) {
        List<DataPoint> dataPoints = dataSet.getDataPoints();
        int N = dataPoints.size();

        this.gaussians = new ArrayList<>(K);
        for (int k = 0; k < K; k++) {
            this.gaussians.add(new NormalM(means.get(k), covs.get(k)));
        }

        // p_ik[i][k]: normalized weight of component k for point i, filled in by eStep.
        double[][] p_ik = new double[N][K];
        double logLike;
        try {
            logLike = requireNotNaN(eStep(N, dataPoints, K, p_ik, parallel));
            for (int iter = 0; iter < MaxIterLimit; iter++) {
                mStep(means, N, dataPoints, K, p_ik, covs, parallel);
                double nextLogLike = requireNotNaN(eStep(N, dataPoints, K, p_ik, parallel));
                boolean converged = Math.abs(logLike - nextLogLike) < this.tolerance;
                logLike = nextLogLike;
                if (converged) {
                    break;
                }
            }
        } catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            // Fail loudly instead of logging and looping again on stale state.
            ClusterFailureException failure = new ClusterFailureException("EM iteration failed: " + e);
            failure.initCause(e);
            throw failure;
        }

        // Move each point to a strictly more probable component than its current one.
        for (int i = 0; i < N; i++) {
            double[] weights = p_ik[i];
            for (int k = 0; k < K; k++) {
                if (weights[k] > weights[assignment[i]]) {
                    assignment[i] = k;
                }
            }
        }
        return -logLike;
    }

    /**
     * Guards the convergence loop. A NaN log-likelihood never passes the
     * tolerance test, so without this check EM would loop indefinitely.
     *
     * @param logLike the log-likelihood returned by an E-step
     * @return {@code logLike} unchanged
     * @throws ArithmeticException if {@code logLike} is NaN
     */
    private static double requireNotNaN(double logLike) {
        if (Double.isNaN(logLike)) {
            throw new ArithmeticException("EM log-likelihood became NaN");
        }
        return logLike;
    }

    /**
     * M-step: re-estimates each component's mean, covariance and mixing weight
     * from the E-step weights, then updates the {@link NormalM} components.
     *
     * @param means      the component means; overwritten
     * @param N          the number of data points
     * @param dataPoints the data
     * @param K          the number of components
     * @param p_ik       the weights produced by the last E-step
     * @param covs       the component covariances; overwritten
     * @param parallel   whether the accumulation may run in parallel
     */
    private void mStep(final List<Vec> means, final int N, final List<DataPoint> dataPoints, final int K,
            final double[][] p_ik, final List<Matrix> covs, boolean parallel) throws InterruptedException {
        for (int k = 0; k < K; k++) {
            means.get(k).zeroOut();
        }
        Arrays.fill(this.a_k, 0.0d);

        // Weighted sums of the points per component; each chunk merges under the mean's lock.
        ThreadLocal<Vec> localMean = ThreadLocal.withInitial(
                () -> new DenseVector(dataPoints.get(0).numNumericalValues()));
        ParallelUtils.run(parallel, N, (start, end) -> {
            Vec partialMean = localMean.get();
            for (int k = 0; k < K; k++) {
                partialMean.zeroOut();
                double partialWeight = 0.0d;
                for (int i = start; i < end; i++) {
                    double w = p_ik[i][k];
                    partialWeight += w;
                    partialMean.mutableAdd(w, dataPoints.get(i).getNumericalValues());
                }
                Vec mean = means.get(k);
                synchronized (mean) {
                    mean.mutableAdd(partialMean);
                    this.a_k[k] += partialWeight;
                }
            }
        });

        // a_k currently holds each component's total weight, which normalizes the weighted sums.
        for (int k = 0; k < K; k++) {
            means.get(k).mutableDivide(this.a_k[k]);
        }

        for (int k = 0; k < K; k++) {
            covs.get(k).zeroOut();
        }
        ParallelUtils.run(parallel, N, (start, end) -> {
            DenseVector diff = new DenseVector(means.get(0).length());
            Matrix partialCov = covs.get(0).clone();
            for (int k = 0; k < K; k++) {
                Vec mean = means.get(k);
                partialCov.zeroOut();
                for (int i = start; i < end; i++) {
                    dataPoints.get(i).getNumericalValues().copyTo(diff);
                    diff.mutableSubtract(mean);
                    Matrix.OuterProductUpdate(partialCov, diff, diff, p_ik[i][k]);
                }
                Matrix cov = covs.get(k);
                synchronized (cov) {
                    cov.mutableAdd(partialCov);
                }
            }
        });

        // Normalize the covariances by component weight, then turn the weights into mixing proportions.
        for (int k = 0; k < K; k++) {
            covs.get(k).mutableMultiply(1.0d / this.a_k[k]);
            this.a_k[k] /= N;
        }
        for (int k = 0; k < K; k++) {
            this.gaussians.get(k).setMeanCovariance(means.get(k), covs.get(k));
        }
    }

    /**
     * E-step: computes and normalizes each component's weight for every point
     * and accumulates the log-likelihood.
     * <p>
     * For every point {@code i}:
     * <ul>
     * <li>{@code p_ik[i][k]} is first set to {@code a_k[k] * pdf_k(x_i)}.</li>
     * <li>The row is then divided by its sum.</li>
     * <li>The log of that sum is added to the log-likelihood.</li>
     * </ul>
     *
     * @param N          the number of data points
     * @param dataPoints the data
     * @param K          the number of components
     * @param p_ik       receives the normalized weights, shape {@code [N][K]}
     * @param parallel   whether the computation may run in parallel
     * @return the log-likelihood of the data under the current mixture
     */
    private double eStep(final int N, final List<DataPoint> dataPoints, final int K, final double[][] p_ik,
            boolean parallel) throws InterruptedException, ExecutionException {
        Double logLike = ParallelUtils.run(parallel, N, (start, end) -> {
            double partialLogLike = 0.0d;
            for (int i = start; i < end; i++) {
                Vec x = dataPoints.get(i).getNumericalValues();
                double[] row = p_ik[i];
                double total = 0.0d;
                for (int k = 0; k < K; k++) {
                    row[k] = this.a_k[k] * this.gaussians.get(k).pdf(x);
                    total += row[k];
                }
                for (int k = 0; k < K; k++) {
                    row[k] /= total;
                }
                partialLogLike += Math.log(total);
            }
            return partialLogLike;
        }, (a, b) -> a + b);
        return logLike;
    }

    /**
     * Returns the log of {@link #pdf(Vec)}. When the density is exactly zero,
     * {@code -Double.MAX_VALUE} is returned instead of negative infinity.
     *
     * @param x the point to evaluate
     * @return the log density at {@code x}
     */
    @Override
    public double logPdf(Vec x) {
        double density = pdf(x);
        if (density == 0.0d) {
            return -Double.MAX_VALUE;
        }
        return Math.log(density);
    }

    /**
     * Returns the mixture density: the sum over components of
     * {@code a_k[k] * pdf_k(x)}. The model must have been fit first.
     *
     * @param x the point to evaluate
     * @return the density at {@code x}
     */
    @Override
    public double pdf(Vec x) {
        double density = 0.0d;
        for (int k = 0; k < this.a_k.length; k++) {
            density += this.a_k[k] * this.gaussians.get(k).pdf(x);
        }
        return density;
    }

    /**
     * Fits the mixture to a list of vectors by wrapping each one in a
     * {@link DataPoint} with no categorical features.
     *
     * @param data     the vectors to fit
     * @param parallel whether fitting may run in parallel
     * @return the result of {@link #setUsingData(DataSet, boolean)}
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> data, boolean parallel) {
        List<DataPoint> points = new ArrayList<>(data.size());
        for (V v : data) {
            points.add(new DataPoint(v, new int[0], new CategoricalData[0]));
        }
        return setUsingData(new SimpleDataSet(points), parallel);
    }

    /**
     * Fits the mixture by clustering {@code dataSet}.
     * <p>
     * NOTE(review): this delegates to the inherited {@code cluster(DataSet, boolean)}.
     * If that reaches {@link #cluster(DataSet, boolean, int[])}, it throws
     * {@link UnsupportedOperationException}, because the number of components
     * cannot be chosen automatically. Confirm against the Clusterer interface.
     *
     * @param dataSet  the data to fit
     * @param parallel whether fitting may run in parallel
     * @return {@code true} on success, {@code false} if an
     *         {@link ArithmeticException} was raised while fitting
     */
    @Override
    public boolean setUsingData(DataSet dataSet, boolean parallel) {
        try {
            cluster(dataSet, parallel);
            return true;
        } catch (ArithmeticException e) {
            return false;
        }
    }

    /**
     * Returns a deep copy of this model.
     *
     * @return a copy made with the copy constructor
     */
    @Override
    public EMGaussianMixture clone() {
        return new EMGaussianMixture(this);
    }

    /**
     * Draws {@code count} samples from the fitted mixture.
     * <p>
     * Each sample's component is chosen by comparing a uniform draw against the
     * cumulative mixing weights. Any draw not covered because of floating-point
     * round-off goes to the last component. The returned list is grouped by
     * component in component order; it is not shuffled.
     *
     * @param count the number of samples to draw
     * @param rand  the source of randomness
     * @return exactly {@code count} samples
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        List<Vec> samples = new ArrayList<>(count);
        double[] targets = new double[count];
        for (int i = 0; i < count; i++) {
            targets[i] = rand.nextDouble();
        }
        Arrays.sort(targets);

        int pos = 0;
        double cumulativeWeight = 0.0d;
        for (int k = 0; k < this.a_k.length; k++) {
            cumulativeWeight += this.a_k[k];
            boolean lastComponent = k == this.a_k.length - 1;
            int first = pos;
            while (pos < count && (lastComponent || targets[pos] < cumulativeWeight)) {
                pos++;
            }
            int componentCount = pos - first;
            if (componentCount > 0) {
                samples.addAll(this.gaussians.get(k).sample(componentCount, rand));
            }
        }
        return samples;
    }

    /**
     * Delegates to {@link #cluster(DataSet, int, int, int[])}, which this class
     * does not support. This method therefore always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return cluster(dataSet, 2, (int) Math.sqrt(dataSet.getSampleSize() / 2), designations);
    }

    /**
     * Delegates to {@link #cluster(DataSet, int, int, boolean, int[])}, which this
     * class does not support. This method therefore always throws.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return cluster(dataSet, 2, (int) Math.sqrt(dataSet.getSampleSize() / 2), parallel, designations);
    }

    /**
     * Fits a mixture with exactly {@code clusters} components and hard-assigns
     * every point to its most probable component.
     *
     * @param dataSet      the data to cluster
     * @param clusters     the number of components, at least 1
     * @param parallel     whether EM steps may run in parallel
     * @param designations array to receive the assignments, or {@code null} to
     *                     allocate one
     * @return the cluster assignment of every point
     * @throws IllegalArgumentException if {@code clusters < 1}
     * @throws ClusterFailureException  if there are fewer points than clusters
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        if (clusters < 1) {
            throw new IllegalArgumentException("Number of clusters must be positive, not " + clusters);
        }
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < clusters) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        cluster(dataSet, null, clusters, new ArrayList<>(clusters), designations, false, parallel, false);
        return designations;
    }

    /**
     * Not supported: the number of components must be given explicitly.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }

    /**
     * Not supported: the number of components must be given explicitly.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        throw new UnsupportedOperationException("EMGaussianMixture does not supported determining the number of clusters");
    }
}
