/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.clustering.kmeans.KMeans;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.MatrixStatistics;
import jsat.linear.Vec;

/**
 * K-means wrapper that grows the number of clusters on its own. Starting from
 * an initial set of means, every cluster is tentatively split in two by the
 * wrapped {@link KMeans} instance; the cluster's points are projected onto the
 * line joining the two child centers and an Anderson-Darling style normality
 * statistic is computed on that projection. If the statistic exceeds a fixed
 * critical value the split is kept, otherwise the cluster is left alone.
 * Rounds of splitting repeat until a round adds no new mean.
 * <p>
 * Defaults: {@code trustH0 = true}, {@code iterativeRefine = true},
 * {@code minClusterSize = 25}, and {@link HamerlyKMeans} as the base algorithm.
 */
public class GMeans
extends KMeans {
    private static final long serialVersionUID = 7306976407786792661L;

    /**
     * Critical value the corrected normality statistic is compared against; a
     * cluster is split only when its statistic is strictly greater than this.
     */
    private static final double NORMALITY_THRESHOLD = 1.8692;

    /** If true, a cluster that passed the normality test once is never tested again. */
    private boolean trustH0 = true;
    /** If true, k-means is re-run over the whole data set after every round of splitting. */
    private boolean iterativeRefine = true;
    /** Clusters with fewer points than this are never considered for splitting. */
    private int minClusterSize = 25;
    /** The k-means implementation all actual clustering work is delegated to. */
    private KMeans kmeans;

    /**
     * Creates a new instance that uses a {@link HamerlyKMeans} as the base
     * clustering algorithm.
     */
    public GMeans() {
        this(new HamerlyKMeans());
    }

    /**
     * Creates a new instance that delegates to the given k-means
     * implementation. The distance metric, seed selection and random source of
     * {@code kmeans} are passed to the superclass, and {@code kmeans} is
     * switched to store its means because they are read back after each run.
     *
     * @param kmeans the k-means implementation to delegate to
     */
    public GMeans(KMeans kmeans) {
        super(kmeans.dm, kmeans.seedSelection, kmeans.rand);
        this.kmeans = kmeans;
        kmeans.setStoreMeans(true);
    }

    /**
     * Copy constructor. The wrapped k-means instance is cloned, all settings
     * are copied.
     *
     * @param toCopy the instance to copy
     */
    public GMeans(GMeans toCopy) {
        super(toCopy);
        this.kmeans = toCopy.kmeans.clone();
        this.trustH0 = toCopy.trustH0;
        this.iterativeRefine = toCopy.iterativeRefine;
        this.minClusterSize = toCopy.minClusterSize;
    }

    /**
     * Controls whether a cluster that passed the normality test (was not
     * split) is excluded from all later rounds of testing.
     *
     * @param trustH0 {@code true} to stop re-testing clusters that passed once
     */
    public void setTrustH0(boolean trustH0) {
        this.trustH0 = trustH0;
    }

    /**
     * @return {@code true} if clusters that passed the normality test are not
     * tested again
     */
    public boolean getTrustH0() {
        return this.trustH0;
    }

    /**
     * Sets the minimum number of points a cluster must contain before a split
     * is attempted.
     *
     * @param minClusterSize the minimum cluster size, at least 2
     * @throws IllegalArgumentException if {@code minClusterSize} is less than 2
     */
    public void setMinClusterSize(int minClusterSize) {
        if (minClusterSize < 2) {
            throw new IllegalArgumentException("min cluster size that could be split is 2, not " + minClusterSize);
        }
        this.minClusterSize = minClusterSize;
    }

    /**
     * @return the minimum number of points a cluster must contain before a
     * split is attempted
     */
    public int getMinClusterSize() {
        return this.minClusterSize;
    }

    /**
     * Controls when the full-data-set k-means refinement is run: after every
     * round of splitting ({@code true}) or only once after the last round
     * ({@code false}).
     *
     * @param refineCenters {@code true} to refine after every round
     */
    public void setIterativeRefine(boolean refineCenters) {
        this.iterativeRefine = refineCenters;
    }

    /**
     * @return {@code true} if the means are refined after every round of
     * splitting
     */
    public boolean getIterativeRefine() {
        return this.iterativeRefine;
    }

    /**
     * Clusters the data set starting from a single cluster and allowing up to
     * {@code max(sampleSize / 20, 10)} clusters.
     *
     * @param dataSet the data to cluster
     * @param designations array to store the assignments in, may be {@code null}
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, 1, Math.max(dataSet.getSampleSize() / 20, 10), designations);
    }

    /**
     * Clusters the data set starting from a single cluster and allowing up to
     * {@code max(sampleSize / 20, 10)} clusters.
     *
     * @param dataSet the data to cluster
     * @param parallel forwarded to the wrapped k-means implementation
     * @param designations array to store the assignments in, may be {@code null}
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 1, Math.max(dataSet.getSampleSize() / 20, 10), parallel, designations);
    }

    /**
     * Clusters the data set, growing the number of clusters from {@code lowK}
     * until no cluster fails the normality test or {@code highK} clusters
     * exist.
     *
     * @param dataSet the data to cluster
     * @param lowK the number of clusters to start with; values below 2 start
     * from one cluster centered on the mean of the data
     * @param highK the number of clusters at which splitting stops
     * @param parallel forwarded to the wrapped k-means implementation
     * @param designations array to store the assignments in; replaced by a new
     * array when {@code null} or shorter than the data set
     * @return the cluster assignment of every data point
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        // N is the size of the whole data set; n (below) is the size of one cluster.
        final int N = dataSet.getSampleSize();
        if (lowK >= 2) {
            designations = this.kmeans.cluster(dataSet, lowK, parallel, designations);
            this.means = new ArrayList<Vec>(this.kmeans.getMeans());
        } else {
            // Start with one cluster (id 0) whose center is the mean of all the data.
            if (designations == null || designations.length < N) {
                designations = new int[N];
            } else {
                Arrays.fill(designations, 0);
            }
            this.means = new ArrayList<Vec>(Arrays.asList(MatrixStatistics.meanVector(dataSet)));
        }

        // subS is handed to getDatapointsFromCluster and afterwards used to map a
        // position in the extracted cluster back to an index into designations.
        int[] subS = new int[designations.length];
        // Assignments (0 or 1) produced by the 2-means run on a single cluster.
        int[] subC = new int[designations.length];
        Vec v = new DenseVector(dataSet.getNumNumericalVars());
        // Scratch space for projected values, big enough for the largest possible cluster.
        double[] xp = new double[N];
        // dontRedo.get(c) == true means cluster c is skipped in all later rounds.
        List<Boolean> dontRedo = new ArrayList<Boolean>(Collections.nCopies(this.means.size(), false));
        // Computed once and reused by every refinement call below.
        List<Double> accelCache = this.dm.getAccelerationCache(dataSet.getDataVectors(), parallel);

        int meansAtRoundStart;
        do {
            meansAtRoundStart = this.means.size();
            // Only clusters that existed when the round began are examined; means
            // appended during the round are picked up by the next one.
            for (int c = 0; c < meansAtRoundStart; ++c) {
                if (dontRedo.get(c)) {
                    continue;
                }
                List<DataPoint> X = GMeans.getDatapointsFromCluster(c, designations, dataSet, subS);
                final int n = X.size();
                // '>=' rather than '==' so the cap still holds if the initial
                // clustering already produced more than highK means.
                if (n < this.minClusterSize || this.means.size() >= highK) {
                    continue;
                }

                // Tentatively split this cluster in two.
                SimpleDataSet subSet = new SimpleDataSet(X);
                subC = this.kmeans.cluster(subSet, 2, parallel, subC);
                List<Vec> subMean = this.kmeans.getMeans();
                Vec c1 = subMean.get(0);
                Vec c2 = subMean.get(1);

                // v = c1 - c2, the direction joining the two child centers.
                c1.copyTo(v);
                v.mutableSubtract(c2);
                final double vNrmSqrd = Math.pow(v.pNorm(2.0), 2.0);
                if (Double.isNaN(vNrmSqrd) || vNrmSqrd < 1.0E-6) {
                    // Child centers (nearly) coincide: no usable direction to test along.
                    continue;
                }

                // Project every point of the cluster onto v.
                for (int i = 0; i < n; ++i) {
                    xp[i] = X.get(i).getNumericalValues().dot(v) / vNrmSqrd;
                }

                double A = GMeans.normalityStatistic(xp, n);
                if (A <= NORMALITY_THRESHOLD) {
                    // Projection looks Gaussian enough: keep the cluster whole.
                    if (this.trustH0) {
                        dontRedo.set(c, true);
                    }
                    continue;
                }

                // Accept the split. Points of child 0 keep id c, points of child 1
                // get the next free id, which is the current number of means.
                for (int i = 0; i < n; ++i) {
                    if (subC[i] == 1) {
                        designations[subS[i]] = this.means.size();
                    }
                }
                this.means.set(c, c1.clone());
                this.means.add(c2.clone());
                dontRedo.add(false);
            }
            // Refine all centers against the full data set between rounds.
            if (this.iterativeRefine && this.means.size() > 1) {
                this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, false, parallel, false, null);
            }
        } while (meansAtRoundStart < this.means.size());

        // Without per-round refinement, do a single refinement at the end.
        if (!this.iterativeRefine && this.means.size() > 1) {
            this.kmeans.cluster(dataSet, accelCache, this.means.size(), this.means, designations, false, parallel, false, null);
        }
        return designations;
    }

    /**
     * Computes the small-sample corrected normality statistic for the first
     * {@code n} entries of {@code xp}. The values are sorted, shifted to zero
     * mean, scaled by their standard deviation (floored at 1e-6), mapped
     * through the standard normal CDF and then summed as
     * {@code sum((2i-1) ln(z_i) + (2(n-i)+1) ln(1-z_i))}, i = 1..n.
     *
     * @param xp projected values; the first {@code n} entries are sorted and
     * may be overwritten
     * @param n the number of valid entries in {@code xp}
     * @return the corrected statistic
     */
    private static double normalityStatistic(double[] xp, int n) {
        // The summation below pairs the i-th coefficient with the i-th smallest value.
        Arrays.sort(xp, 0, n);
        DenseVector Xp = new DenseVector(xp, 0, n);
        Xp.mutableSubtract(Xp.mean());
        Xp.mutableDivide(Math.max(Xp.standardDeviation(), 1.0E-6));
        for (int i = 0; i < Xp.length(); ++i) {
            Xp.set(i, Normal.cdf(Xp.get(i), 0.0, 1.0));
        }

        double A = 0.0;
        for (int i = 1; i <= Xp.length(); ++i) {
            double phi = Xp.get(i - 1);
            A += (double)(2 * i - 1) * Math.log(phi) + (double)(2 * (n - i) + 1) * Math.log(1.0 - phi);
        }
        A /= (double)(-n);
        A += (double)(-n);
        // Multiply in double: n * n as an int wraps for n >= 46341 (and is
        // exactly 0 at n = 65536, which made the correction factor -Infinity).
        A *= 1.0 + 4.0 / (double)n - 25.0 / ((double)n * (double)n);
        return A;
    }

    /**
     * @return the iteration limit of the wrapped k-means implementation
     */
    @Override
    public int getIterationLimit() {
        return this.kmeans.getIterationLimit();
    }

    /**
     * Sets the iteration limit on the wrapped k-means implementation.
     *
     * @param iterLimit the new iteration limit
     */
    @Override
    public void setIterationLimit(int iterLimit) {
        this.kmeans.setIterationLimit(iterLimit);
    }

    /**
     * Forwards the seed selection method to the wrapped k-means
     * implementation. Does nothing while the wrapped instance is still
     * {@code null}; presumably this guards against being invoked from a
     * superclass constructor before the field is assigned - TODO confirm.
     *
     * @param seedSelection the seed selection method to use
     */
    @Override
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        if (this.kmeans != null) {
            this.kmeans.setSeedSelection(seedSelection);
        }
    }

    /**
     * @return the seed selection method of the wrapped k-means implementation
     */
    @Override
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.kmeans.getSeedSelection();
    }

    /**
     * Delegates directly to the wrapped k-means implementation.
     * <p>
     * NOTE(review): {@code dataPointWeights} is not forwarded; {@code null} is
     * always passed to the delegate. Confirm whether dropping the weights is
     * intentional.
     */
    @Override
    protected double cluster(DataSet dataSet, List<Double> accelCache, int k, List<Vec> means, int[] assignment, boolean exactTotal, boolean threadpool, boolean returnError, Vec dataPointWeights) {
        return this.kmeans.cluster(dataSet, accelCache, k, means, assignment, exactTotal, threadpool, returnError, null);
    }

    /**
     * @return a copy of this object made with the copy constructor
     */
    @Override
    public GMeans clone() {
        return new GMeans(this);
    }
}

