package jsat.clustering;

import java.util.Arrays;
import java.util.Random;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.clustering.evaluation.IntraClusterSumEvaluation;
import jsat.clustering.evaluation.intra.SumOfSqrdPairwiseDistances;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/clustering/GapStatistic.class */
/**
 * Chooses the number of clusters with the gap statistic (Tibshirani, Walther &amp; Hastie, 2001).
 * <p>
 * For every candidate k the base {@link KClusterer} is run and the log of the within-cluster
 * dispersion, log W<sub>k</sub>, is recorded, where the dispersion is measured by an
 * {@link IntraClusterSumEvaluation} over {@link SumOfSqrdPairwiseDistances} with the configured
 * {@link DistanceMetric}. The same quantity is averaged over {@link #getSamples() B} reference data
 * sets drawn uniformly from a bounding box of the data (optionally aligned with the right singular
 * vectors of the data matrix). With Gap(k) = mean(log W<sub>k</sub> of reference) - log W<sub>k</sub>
 * and s<sub>k</sub> = sd<sub>k</sub> * sqrt(1 + 1/B), the selected k is the smallest one in the
 * requested range with Gap(k) &gt; 0 and Gap(k) &ge; Gap(k+1) - s<sub>k+1</sub>. If none qualifies,
 * the k with the largest gap in the range is used.
 */
public class GapStatistic extends KClustererBase implements Parameterized {
    private static final long serialVersionUID = 8893929177942856618L;

    // Field names are kept as-is: the class is Serializable, so renaming would break
    // deserialization of previously stored instances.

    /** Clustering algorithm run for every candidate number of clusters. */
    @Parameter.ParameterHolder
    private KClusterer base;
    /** Number of reference data sets drawn per call; always positive. */
    private int B;
    /** Metric used by the within-cluster dispersion measure. */
    private DistanceMetric dm;
    /** Whether reference data is drawn in a box aligned with the right singular vectors. */
    private boolean PCSampling;
    /** Index k-1 holds the mean reference log-dispersion for k clusters (last run). */
    private double[] ElogW;
    /** Index k-1 holds the observed log-dispersion for k clusters (last run). */
    private double[] logW;
    /** Index k-1 holds Gap(k) (last run). */
    private double[] gap;
    /** Index k-1 holds the inflated standard deviation s_k (last run). */
    private double[] s_k;

    /** Creates a gap-statistic clusterer using {@link HamerlyKMeans} as the base algorithm. */
    public GapStatistic() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates a gap-statistic clusterer with axis-aligned uniform reference sampling.
     *
     * @param base the algorithm run for each candidate number of clusters
     */
    public GapStatistic(KClusterer base) {
        this(base, false);
    }

    /**
     * Creates a gap-statistic clusterer with 10 reference samples and Euclidean distance.
     *
     * @param base       the algorithm run for each candidate number of clusters
     * @param PCSampling {@code true} to draw reference data in a singular-vector-aligned box
     */
    public GapStatistic(KClusterer base, boolean PCSampling) {
        this(base, PCSampling, 10, new EuclideanDistance());
    }

    /**
     * Creates a fully configured gap-statistic clusterer.
     *
     * @param base       the algorithm run for each candidate number of clusters
     * @param PCSampling {@code true} to draw reference data in a singular-vector-aligned box
     * @param B          number of reference data sets; must be positive
     * @param dm         metric used when measuring within-cluster dispersion
     * @throws IllegalArgumentException if {@code B <= 0}
     */
    public GapStatistic(KClusterer base, boolean PCSampling, int B, DistanceMetric dm) {
        this.base = base;
        setSamples(B);
        setDistanceMetric(dm);
        setPCSampling(PCSampling);
    }

    /**
     * Copy constructor; the base clusterer, metric and result arrays are cloned.
     *
     * @param toCopy the object to copy
     */
    public GapStatistic(GapStatistic toCopy) {
        this.base = toCopy.base.mo114clone();
        this.B = toCopy.B;
        this.dm = toCopy.dm.mo185clone();
        this.PCSampling = toCopy.PCSampling;
        this.ElogW = copyOrNull(toCopy.ElogW);
        this.logW = copyOrNull(toCopy.logW);
        this.gap = copyOrNull(toCopy.gap);
        this.s_k = copyOrNull(toCopy.s_k);
    }

    /** Returns a copy of {@code array}, or {@code null} when it is {@code null}. */
    private static double[] copyOrNull(double[] array) {
        return array == null ? null : array.clone();
    }

    /** @param dm the metric used when measuring within-cluster dispersion */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the metric used when measuring within-cluster dispersion */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** @param PCSampling {@code true} to draw reference data in a singular-vector-aligned box */
    public void setPCSampling(boolean PCSampling) {
        this.PCSampling = PCSampling;
    }

    /** @return whether reference data is drawn in a singular-vector-aligned box */
    public boolean isPCSampling() {
        return this.PCSampling;
    }

    /**
     * Sets how many reference data sets are drawn per clustering call.
     *
     * @param B the number of reference data sets
     * @throws IllegalArgumentException if {@code B <= 0}
     */
    public void setSamples(int B) {
        if (B <= 0) {
            throw new IllegalArgumentException("sample size must be positive, not " + B);
        }
        this.B = B;
    }

    /** @return the number of reference data sets drawn per clustering call */
    public int getSamples() {
        return this.B;
    }

    /**
     * @return Gap(k) at index k-1 from the last range-based clustering, or {@code null} if none
     *         has run; this is the internal array, not a copy
     */
    public double[] getGap() {
        return this.gap;
    }

    /**
     * @return observed log-dispersion at index k-1 from the last range-based clustering, or
     *         {@code null} if none has run; this is the internal array, not a copy
     */
    public double[] getLogW() {
        return this.logW;
    }

    /**
     * @return mean reference log-dispersion at index k-1 from the last range-based clustering,
     *         or {@code null} if none has run; this is the internal array, not a copy
     */
    public double[] getElogW() {
        return this.ElogW;
    }

    /**
     * @return s_k (reference standard deviation times sqrt(1 + 1/B)) at index k-1 from the last
     *         range-based clustering, or {@code null} if none has run; internal array, not a copy
     */
    public double[] getElogWkStndDev() {
        return this.s_k;
    }

    /**
     * Chooses k in [1, min(max(sqrt(N), 10), 100)] with the gap statistic and clusters with it.
     */
    @Override // jsat.clustering.Clusterer
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        int maxK = (int) Math.min(Math.max(Math.sqrt(dataSet.getSampleSize()), 10.0d), 100.0d);
        return cluster(dataSet, 1, maxK, parallel, designations);
    }

    /**
     * Clusters into exactly {@code clusters} clusters by delegating to the base algorithm; no gap
     * statistic is computed.
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        return this.base.cluster(dataSet, clusters, parallel, designations);
    }

    /**
     * Chooses a number of clusters in [lowK, highK] (both inclusive) with the gap statistic, then
     * clusters the data with it. The per-k statistics of this run are stored and exposed through
     * the getters (index k-1 holds the value for k clusters, k = 1..highK).
     *
     * @param dataSet      the data to cluster
     * @param lowK         smallest acceptable number of clusters; values below 1 are treated as 1
     * @param highK        largest acceptable number of clusters
     * @param parallel     passed through to the base clusterer
     * @param designations optional output buffer; replaced when {@code null} or too short
     * @return cluster assignments for every data point
     * @throws IllegalArgumentException if {@code highK < max(lowK, 1)}
     */
    @Override // jsat.clustering.KClusterer
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        final int lo = Math.max(lowK, 1);
        if (highK < lo) {
            throw new IllegalArgumentException(
                    "highK must be at least max(lowK, 1) = " + lo + ", not " + highK);
        }
        final int N = dataSet.getSampleSize();
        final int D = dataSet.getNumNumericalVars();
        if (designations == null || designations.length < N) {
            designations = new int[N];
        }

        // Index k-1 of every array holds the value for k clusters, k = 1..highK (inclusive).
        this.logW = new double[highK];
        this.ElogW = new double[highK];
        this.gap = new double[highK];
        this.s_k = new double[highK];

        IntraClusterSumEvaluation dispersion =
                new IntraClusterSumEvaluation(new SumOfSqrdPairwiseDistances(this.dm));

        // Step 1: log-dispersion of the observed data for every k.
        designations = logDispersions(dataSet, highK, parallel, designations, dispersion, this.logW);

        // Step 2: the same quantity for B reference data sets, accumulated per k.
        OnLineStatistics[] expected = new OnLineStatistics[highK];
        for (int idx = 0; idx < highK; idx++) {
            expected[idx] = new OnLineStatistics();
        }
        // A single reference data set is allocated once and refilled for every draw.
        SimpleDataSet reference = new SimpleDataSet(new CategoricalData[0], D);
        for (int row = 0; row < N; row++) {
            reference.add(new DataPoint(new DenseVector(D)));
        }
        double[] min = new double[D];
        double[] max = new double[D];
        Matrix backTransform = samplingBox(dataSet, min, max);
        Random rand = RandomUtil.getRandom();
        double[] refLogW = new double[highK];
        for (int b = 0; b < this.B; b++) {
            drawReference(reference, min, max, backTransform, rand);
            designations = logDispersions(reference, highK, parallel, designations, dispersion, refLogW);
            for (int idx = 0; idx < highK; idx++) {
                expected[idx].add(refLogW[idx]);
            }
        }

        // Step 3: gap and its simulation error s_k = sd_k * sqrt(1 + 1/B).
        double errorScale = Math.sqrt(1.0d + (1.0d / this.B));
        for (int idx = 0; idx < highK; idx++) {
            this.ElogW[idx] = expected[idx].getMean();
            this.gap[idx] = this.ElogW[idx] - this.logW[idx];
            this.s_k[idx] = expected[idx].getStandardDeviation() * errorScale;
        }

        int k = chooseK(this.gap, this.s_k, lo, highK);
        if (k == 1) {
            // A single cluster needs no base clustering run.
            Arrays.fill(designations, 0);
            return designations;
        }
        return this.base.cluster(dataSet, k, parallel, designations);
    }

    /**
     * Clusters {@code data} for every k in [1, maxK] and stores the log-dispersion of the result
     * in {@code out[k-1]}.
     *
     * @return the designation buffer; the base clusterer may hand back a different array
     */
    private int[] logDispersions(DataSet data, int maxK, boolean parallel, int[] designations,
                                 IntraClusterSumEvaluation dispersion, double[] out) {
        // k = 1 is computed directly: every point assigned to cluster 0.
        Arrays.fill(designations, 0);
        out[0] = Math.log(dispersion.evaluate(designations, data));
        for (int k = 2; k <= maxK; k++) {
            designations = this.base.cluster(data, k, parallel, designations);
            out[k - 1] = Math.log(dispersion.evaluate(designations, data));
        }
        return designations;
    }

    /**
     * Fills {@code min}/{@code max} with the per-coordinate bounds of the box reference data is
     * drawn from.
     *
     * @return the matrix (V transposed) that maps a sampled point back to the original feature
     *         space when PC sampling is enabled, otherwise {@code null}
     */
    private Matrix samplingBox(DataSet dataSet, double[] min, double[] max) {
        if (!this.PCSampling) {
            OnLineStatistics[] columnStats = dataSet.getOnlineColumnStats(false);
            for (int j = 0; j < min.length; j++) {
                min[j] = columnStats[j].getMin();
                max[j] = columnStats[j].getMax();
            }
            return null;
        }
        // NOTE(review): Tibshirani et al. describe centering the columns before the SVD; the data
        // matrix is used uncentered here. Confirm whether that is intended.
        SingularValueDecomposition svd = new SingularValueDecomposition(dataSet.getDataMatrix());
        Matrix v = svd.getV();
        Matrix projected = dataSet.getDataMatrixView().multiply(v);
        Arrays.fill(min, Double.POSITIVE_INFINITY);
        Arrays.fill(max, Double.NEGATIVE_INFINITY);
        for (int r = 0; r < projected.rows(); r++) {
            for (int c = 0; c < projected.cols(); c++) {
                double value = projected.get(r, c);
                min[c] = Math.min(value, min[c]);
                max[c] = Math.max(value, max[c]);
            }
        }
        return v.transpose();
    }

    /**
     * Overwrites every point of {@code reference} with a uniform draw from the box
     * [min, max], then maps it through {@code backTransform} when that is non-null.
     */
    private static void drawReference(SimpleDataSet reference, double[] min, double[] max,
                                      Matrix backTransform, Random rand) {
        final int D = min.length;
        DenseVector projected = backTransform == null ? null : new DenseVector(D);
        for (int row = 0; row < reference.getSampleSize(); row++) {
            Vec x = reference.getDataPoint(row).getNumericalValues();
            for (int j = 0; j < D; j++) {
                x.set(j, ((max[j] - min[j]) * rand.nextDouble()) + min[j]);
            }
            if (projected != null) {
                projected.zeroOut();
                x.multiply(backTransform, projected);
                projected.copyTo(x);
            }
        }
    }

    /**
     * Picks the number of clusters from the gap values, where {@code gap[k-1]} and {@code sK[k-1]}
     * belong to k clusters.
     *
     * @return the smallest k in [lowK, highK) with Gap(k) &gt; 0 and
     *         Gap(k) &ge; Gap(k+1) - s_{k+1}; if there is none, the k in [lowK, highK] with the
     *         largest gap (smallest such k on ties)
     */
    private static int chooseK(double[] gap, double[] sK, int lowK, int highK) {
        for (int k = lowK; k < highK; k++) {
            double current = gap[k - 1];
            if (current > 0.0d && current >= gap[k] - sK[k]) {
                return k;
            }
        }
        // Fallback: compare gap VALUES (the original compared a gap value against an index).
        int best = lowK;
        for (int k = lowK + 1; k <= highK; k++) {
            if (gap[k - 1] > gap[best - 1]) {
                best = k;
            }
        }
        return best;
    }

    /** Returns a deep copy via {@link #GapStatistic(GapStatistic)}. */
    @Override // jsat.clustering.KClustererBase, jsat.clustering.ClustererBase
    /* renamed from: clone */
    public GapStatistic mo114clone() {
        return new GapStatistic(this);
    }
}
