/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.KClustererBase;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.random.RandomUtil;

/**
 * Common base for k-means style clusterers.
 *
 * <p>This class holds the shared configuration (distance metric, seed
 * selection method, random source, iteration limit) and implements the
 * public {@code cluster} overloads in terms of a single abstract worker
 * method that subclasses provide. It also implements a simple search for
 * the number of clusters when only a range {@code [lowK, highK]} is given.
 */
public abstract class KMeans
extends KClustererBase
implements Parameterized {
    private static final long serialVersionUID = 8730927112084289722L;

    /** Seed selection method exposed as the default for this family of clusterers. */
    public static final SeedSelectionMethods.SeedSelection DEFAULT_SEED_SELECTION = SeedSelectionMethods.SeedSelection.KPP;

    /** Distance metric; also handed out by {@link #getDistanceMetric()}. */
    @Parameter.ParameterHolder
    protected DistanceMetric dm;

    /** How the initial seeds are chosen. */
    protected SeedSelectionMethods.SeedSelection seedSelection;

    /** Source of randomness supplied at construction (a fresh one is obtained when copying). */
    protected Random rand;

    /** When {@code false}, the means are discarded after a fixed-k clustering run. */
    protected boolean storeMeans = true;

    /** Not read in this class; presumably consulted by subclasses - confirm there. */
    protected boolean saveCentroidDistance = true;

    /** Not written in this class; presumably filled in by subclasses - confirm there. */
    protected double[] nearestCentroidDist;

    /** Means from the most recent fixed-k run, or {@code null} if none were stored. */
    protected List<Vec> means;

    /** Upper bound on iterations; unbounded ({@link Integer#MAX_VALUE}) by default. */
    protected int MaxIterLimit = Integer.MAX_VALUE;

    /**
     * Creates a new clusterer.
     *
     * @param dm the distance metric to use
     * @param seedSelection the seed selection method to use
     * @param rand the source of randomness
     */
    public KMeans(DistanceMetric dm, SeedSelectionMethods.SeedSelection seedSelection, Random rand) {
        this.dm = dm;
        this.setSeedSelection(seedSelection);
        this.rand = rand;
    }

    /**
     * Copy constructor. The distance metric, the stored distances and the
     * stored means are cloned/copied; the random source is NOT shared and is
     * instead obtained fresh from {@link RandomUtil#getRandom()}.
     *
     * @param toCopy the object to copy
     */
    public KMeans(KMeans toCopy) {
        this.dm = toCopy.dm.clone();
        this.seedSelection = toCopy.seedSelection;
        this.rand = RandomUtil.getRandom();
        // Carry over the configuration flags too; otherwise a copy silently
        // reverts to the defaults (e.g. an unbounded iteration limit).
        this.storeMeans = toCopy.storeMeans;
        this.saveCentroidDistance = toCopy.saveCentroidDistance;
        this.MaxIterLimit = toCopy.MaxIterLimit;
        if (toCopy.nearestCentroidDist != null) {
            this.nearestCentroidDist = Arrays.copyOf(toCopy.nearestCentroidDist, toCopy.nearestCentroidDist.length);
        }
        if (toCopy.means != null) {
            this.means = new ArrayList<Vec>(toCopy.means.size());
            for (Vec v : toCopy.means) {
                this.means.add(v.clone());
            }
        }
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param iterLimit the new limit, must be at least 1
     * @throws IllegalArgumentException if {@code iterLimit < 1}
     */
    public void setIterationLimit(int iterLimit) {
        if (iterLimit < 1) {
            throw new IllegalArgumentException("Iterations must be a positive value, not " + iterLimit);
        }
        this.MaxIterLimit = iterLimit;
    }

    /**
     * @return the current iteration limit
     */
    public int getIterationLimit() {
        return this.MaxIterLimit;
    }

    /**
     * Controls whether the means of a fixed-k clustering run are kept after
     * the run finishes.
     *
     * @param storeMeans {@code true} to keep the means, {@code false} to discard them
     */
    public void setStoreMeans(boolean storeMeans) {
        this.storeMeans = storeMeans;
    }

    /**
     * Returns the stored means. The internal list itself is returned, not a copy.
     *
     * @return the stored means, or {@code null} if none are stored
     */
    public List<Vec> getMeans() {
        return this.means;
    }

    /**
     * @param seedSelection the seed selection method to use
     */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /**
     * @return the seed selection method in use
     */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /**
     * @return the distance metric in use
     */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * The worker that subclasses implement. Parameter meanings below are
     * taken from how this class calls the method.
     *
     * @param var1 the data set to cluster
     * @param var2 the distance acceleration cache for the data set, or {@code null}
     * @param var3 the number of clusters
     * @param var4 the list that receives the means
     * @param var5 the array that receives the cluster designations
     * @param var6 passed as {@code false} for a fixed-k run and {@code true}
     *             during the k search; exact meaning is not visible here - TODO confirm
     * @param var7 whether to run in parallel
     * @param var8 passed as {@code false} for a fixed-k run and {@code true}
     *             during the k search; presumably requests a meaningful return
     *             value - TODO confirm
     * @param var9 always passed as {@code null} by this class; meaning not
     *             visible here - TODO confirm
     * @return a value this class treats as the total distance of the
     *         clustering when searching for k
     */
    protected abstract double cluster(DataSet var1, List<Double> var2, int var3, List<Vec> var4, int[] var5, boolean var6, boolean var7, boolean var8, Vec var9);

    /**
     * Creates a list holding {@code k} empty, independent lists.
     *
     * @param k the number of inner lists to create
     * @return a new list of {@code k} empty lists
     */
    protected static List<List<DataPoint>> getListOfLists(int k) {
        List<List<DataPoint>> ks = new ArrayList<List<DataPoint>>(k);
        for (int i = 0; i < k; ++i) {
            ks.add(new ArrayList<DataPoint>());
        }
        return ks;
    }

    /**
     * Clusters with the number of clusters searched between 2 and
     * {@code (int) sqrt(sampleSize / 2)}, delegating to the
     * {@code (dataSet, lowK, highK, designations)} overload.
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        return this.cluster(dataSet, 2, (int)Math.sqrt(dataSet.getSampleSize() / 2), designations);
    }

    /**
     * Same as {@link #cluster(DataSet, int[])} but forwards the
     * {@code parallel} flag to the range-search overload.
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 2, (int)Math.sqrt(dataSet.getSampleSize() / 2), parallel, designations);
    }

    /**
     * Clusters the data set into exactly {@code clusters} clusters.
     *
     * @param dataSet the data to cluster
     * @param clusters the number of clusters
     * @param parallel whether to run in parallel
     * @param designations array to fill with the result, or {@code null} to
     *        have one allocated
     * @return the designations array that was filled
     * @throws ClusterFailureException if the data set has fewer points than
     *         {@code clusters}
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < clusters) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        this.means = new ArrayList<Vec>(clusters);
        this.cluster(dataSet, null, clusters, this.means, designations, false, parallel, false, null);
        if (!this.storeMeans) {
            this.means = null;
        }
        return designations;
    }

    /**
     * Clusters the data set, choosing the number of clusters from the range
     * {@code [lowK, highK]}. Every k in the range is tried once to record its
     * total distance, then a k is selected from those totals and the data is
     * clustered one final time with it.
     *
     * @param dataSet the data to cluster
     * @param lowK the smallest number of clusters to consider
     * @param highK the largest number of clusters to consider
     * @param parallel whether the trial runs execute in parallel
     * @param designations array to fill with the result, or {@code null} to
     *        have one allocated
     * @return the designations of the final clustering
     * @throws ClusterFailureException if the data set has fewer points than
     *         {@code highK}
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        if (dataSet.getSampleSize() < highK) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        double[] totDistances = new double[highK - lowK + 1];
        // Build the acceleration cache once and share it across every trial k.
        List<Double> cache = this.dm.getAccelerationCache(dataSet.getDataVectors(), parallel);
        for (int k = lowK; k <= highK; ++k) {
            totDistances[k - lowK] = this.cluster(dataSet, cache, k, new ArrayList<Vec>(), designations, true, parallel, true, null);
        }
        return this.findK(lowK, highK, totDistances, dataSet, designations);
    }

    /**
     * Selects a k from the recorded total distances and performs the final
     * clustering with it.
     *
     * <p>The absolute change in total distance between each pair of
     * consecutive k values is computed. If the largest change is less than
     * the mean change plus two standard deviations, {@code lowK} is used.
     * Otherwise the first k whose change from its predecessor is strictly
     * smaller than the largest change is used.
     *
     * @param lowK the smallest k that was tried
     * @param highK the largest k that was tried
     * @param totDistances total distance per k, index {@code k - lowK}
     * @param dataSet the data to cluster
     * @param designations array to fill with the result
     * @return the designations of the final clustering
     */
    private int[] findK(int lowK, int highK, double[] totDistances, DataSet dataSet, int[] designations) {
        OnLineStatistics stats = new OnLineStatistics();
        double maxChange = Double.MIN_VALUE;
        int maxChangeK = lowK;
        // The first k has no predecessor, so changes start at lowK + 1.
        for (int k = lowK + 1; k <= highK; ++k) {
            double change = Math.abs(totDistances[k - lowK] - totDistances[k - lowK - 1]);
            stats.add(change);
            if (change > maxChange) {
                maxChange = change;
                maxChangeK = k;
            }
        }
        double changeMean = stats.getMean();
        double changeDev = stats.getStandardDeviation();
        if (maxChange < changeDev * 2.0 + changeMean) {
            // No change stands out from the rest: fall back to the smallest k.
            maxChangeK = lowK;
        } else {
            for (int i = 1; i < totDistances.length; ++i) {
                double tmp = Math.abs(totDistances[i] - totDistances[i - 1]);
                // Compare the change just computed (the decompiled code tested
                // an unassigned variable here, which does not compile).
                if (tmp < maxChange) {
                    maxChangeK = i + lowK;
                    break;
                }
            }
        }
        return this.cluster(dataSet, maxChangeK, designations);
    }

    @Override
    public abstract KMeans clone();

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }
}

