/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;
import jsat.DataSet;
import jsat.clustering.ClusterFailureException;
import jsat.clustering.kmeans.KernelKMeans;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.linear.Vec;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Kernel k-means variant that tracks, for every data point, an upper bound on
 * the distance to its assigned centroid and a lower bound on the distance to
 * each centroid. Together with the centroid-to-centroid distances these bounds
 * are used to skip distance evaluations that cannot change an assignment.
 *
 * <p>The fields {@code X}, {@code W}, {@code kernel}, {@code accel},
 * {@code ownes}, {@code meanSqrdNorms}, {@code normConsts},
 * {@code newDesignations} and {@code maximumIterations} used below are not
 * declared in this class; they come from the superclass.
 */
public class ElkanKernelKMeans
extends KernelKMeans {
    private static final long serialVersionUID = 4998832201379993827L;

    /**
     * Symmetric {@code k x k} matrix of distances between centroids, filled in
     * by {@link #calculateCentroidDistances}.
     */
    private double[][] centroidSelfDistances;

    /**
     * {@code k x k} matrix of accumulated, weight-scaled kernel evaluations
     * between the members of each pair of clusters. It is not normalised by
     * the cluster weights; that happens when distances are derived from it.
     */
    private double[][] centroidPairDots;

    /**
     * Creates a new, unfitted clusterer.
     *
     * @param kernel the kernel handed to the superclass
     */
    public ElkanKernelKMeans(KernelTrick kernel) {
        super(kernel);
    }

    /**
     * Copy constructor.
     *
     * <p>Besides whatever the superclass copies, this also deep-copies the two
     * matrices declared in this class. Without that a copy of a fitted model
     * would have {@code centroidSelfDistances == null} and
     * {@link #findClosestCluster(Vec, List)} would fail on it.
     *
     * @param toCopy the instance to copy
     */
    public ElkanKernelKMeans(ElkanKernelKMeans toCopy) {
        super(toCopy);
        this.centroidSelfDistances = copyMatrix(toCopy.centroidSelfDistances);
        this.centroidPairDots = copyMatrix(toCopy.centroidPairDots);
    }

    /**
     * Returns a row-by-row copy of {@code source}, or {@code null} when
     * {@code source} is {@code null}.
     */
    private static double[][] copyMatrix(double[][] source) {
        if (source == null) {
            return null;
        }
        double[][] copy = new double[source.length][];
        for (int row = 0; row < source.length; ++row) {
            copy[row] = source[row].clone();
        }
        return copy;
    }

    /**
     * Finds the cluster with the smallest {@code distance(x, qi, i)}.
     *
     * <p>Clusters whose {@code ownes} entry is at most {@code 1e-15} are
     * ignored. After the distance to cluster {@code i} is known, every later
     * cluster {@code j} with {@code centroidSelfDistances[i][j] >= 2 * dist}
     * is pruned and never evaluated.
     *
     * @param x  the query vector
     * @param qi forwarded unchanged to {@code distance(x, qi, i)}
     * @return the index of the closest cluster considered, or {@code -1} if
     *         no cluster was evaluated
     */
    @Override
    public int findClosestCluster(Vec x, List<Double> qi) {
        int clusterCount = this.meanSqrdNorms.length;
        // A freshly allocated boolean[] is already all false.
        boolean[] pruned = new boolean[clusterCount];
        double min = Double.MAX_VALUE;
        int min_indx = -1;
        for (int i = 0; i < clusterCount; ++i) {
            if (this.ownes[i] <= 1.0E-15 || pruned[i]) {
                continue;
            }
            double dist = this.distance(x, qi, i);
            if (dist < min) {
                min = dist;
                min_indx = i;
            }
            // Knowing d(x, c_i), rule out later centroids that are at least
            // twice as far from c_i.
            for (int j = i + 1; j < clusterCount; ++j) {
                if (this.centroidSelfDistances[i][j] >= 2.0 * dist) {
                    pruned[j] = true;
                }
            }
        }
        return min_indx;
    }

    /**
     * Incrementally updates {@link #centroidPairDots} for the points whose
     * assignment differs between {@code prev_assignments} and
     * {@code new_assignments}.
     *
     * <p>For every pair {@code (i, j)} with {@code j >= i} where at least one
     * of the two points changed cluster, the weighted kernel value is removed
     * from the old cluster pair (only when both old assignments are
     * non-negative) and added to the new cluster pair. Both symmetric entries
     * are updated.
     *
     * <p>Each chunk accumulates into a private matrix and then merges it into
     * the shared matrix one row at a time while holding that row's monitor,
     * so concurrent chunks cannot lose updates.
     *
     * @param prev_assignments previous cluster index per point; negative
     *                         values mean there is nothing to subtract
     * @param new_assignments  current cluster index per point
     * @param parallel         forwarded to {@code ParallelUtils.run}
     */
    private void update_centroid_pair_dots(int[] prev_assignments, int[] new_assignments, boolean parallel) {
        int N = this.X.size();
        ParallelUtils.run(parallel, N, (start, end) -> {
            int clusters = this.centroidPairDots.length;
            double[][] localChanges = new double[clusters][clusters];
            for (int i = start; i < end; ++i) {
                double w_i = this.W.get(i);
                int old_c_i = prev_assignments[i];
                int new_c_i = new_assignments[i];
                for (int j = i; j < N; ++j) {
                    int old_c_j = prev_assignments[j];
                    int new_c_j = new_assignments[j];
                    if (old_c_i == new_c_i && old_c_j == new_c_j) {
                        // Neither point moved, so this pair contributes no change.
                        continue;
                    }
                    double w_j = this.W.get(j);
                    double K_ij = w_i * w_j * this.kernel.eval(i, j, this.X, this.accel);
                    if (old_c_i >= 0 && old_c_j >= 0) {
                        localChanges[old_c_i][old_c_j] -= K_ij;
                        localChanges[old_c_j][old_c_i] -= K_ij;
                    }
                    localChanges[new_c_i][new_c_j] += K_ij;
                    localChanges[new_c_j][new_c_i] += K_ij;
                }
            }
            // Merge this chunk's contribution into the shared matrix. The row
            // lock is required because other chunks merge concurrently.
            for (int c = 0; c < localChanges.length; ++c) {
                double[] sharedRow = this.centroidPairDots[c];
                double[] localRow = localChanges[c];
                synchronized (sharedRow) {
                    for (int j = 0; j < localRow.length; ++j) {
                        sharedRow[j] += localRow[j];
                    }
                }
            }
        });
    }

    /**
     * Runs the bounded clustering loop and writes the result into
     * {@code assignment}.
     *
     * <p>The loop repeats while the previous pass changed an assignment or
     * fewer than two passes have run, and stops once the
     * {@code maximumIterations} budget is used up.
     *
     * @param dataSet    the data to cluster
     * @param k          the number of clusters
     * @param assignment per-point cluster index; passed to {@code setup} and
     *                   overwritten with the new assignments after each pass
     * @param exactTotal currently has no effect: both settings return the same
     *                   sum of squared upper bounds
     * @param parallel   forwarded to {@code ParallelUtils.run}
     * @return the sum over all points of the squared upper bound
     * @throws FailedToFitException wrapping any exception raised during
     *         clustering, including the {@link ClusterFailureException} thrown
     *         when there are fewer points than clusters
     */
    protected double cluster(DataSet dataSet, int k, int[] assignment, boolean exactTotal, boolean parallel) {
        try {
            int N = dataSet.getSampleSize();
            if (N < k) {
                throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
            }
            this.X = dataSet.getDataVectors();
            this.setup(k, assignment, dataSet.getDataWeights());

            double[][] lowerBound = new double[N][k];
            double[] upperBound = new double[N];
            this.centroidSelfDistances = new double[k][k];
            this.centroidPairDots = new double[k][k];
            double[] sC = new double[k];
            // A null previous assignment means "build the pair dots from scratch".
            this.calculateCentroidDistances(k, this.centroidSelfDistances, sC, assignment, null, parallel);

            int[] prev_assignment = new int[N];
            int atLeast = 2;
            AtomicBoolean changeOccurred = new AtomicBoolean(true);
            boolean[] r = new boolean[N];
            this.initialClusterSetUp(k, N, lowerBound, upperBound, this.centroidSelfDistances, assignment, parallel);

            int iterLimit = this.maximumIterations;
            while ((changeOccurred.get() || atLeast > 0) && iterLimit-- >= 0) {
                --atLeast;
                changeOccurred.set(false);
                // Step 1: centroid distances. Skipped on the first pass because
                // they were already computed before the loop.
                if (iterLimit < this.maximumIterations - 1) {
                    this.calculateCentroidDistances(k, this.centroidSelfDistances, sC, assignment, prev_assignment, parallel);
                }
                // Remember the current assignment for the next incremental update.
                System.arraycopy(assignment, 0, prev_assignment, 0, assignment.length);

                // Steps 2 and 3: per-point bound checks.
                ParallelUtils.run(parallel, N, q -> {
                    // Step 2: the upper bound is within half the distance to the
                    // nearest other centroid, so the point cannot move.
                    if (upperBound[q] <= sC[assignment[q]]) {
                        return;
                    }
                    for (int c = 0; c < k; ++c) {
                        if (c != assignment[q]
                                && upperBound[q] > lowerBound[q][c]
                                && upperBound[q] > this.centroidSelfDistances[assignment[q]][c] * 0.5) {
                            this.step3aBoundsUpdate(r, q, assignment, upperBound, lowerBound);
                            this.step3bUpdate(upperBound, q, lowerBound, c, this.centroidSelfDistances, assignment, changeOccurred);
                        }
                    }
                });

                // Steps 4-6: move the means and loosen the bounds. The returned
                // count is not needed here.
                this.step4_5_6_distanceMovedBoundsUpdate(k, N, lowerBound, upperBound, assignment, r, parallel);
            }

            double totalDistance = 0.0;
            for (int i = 0; i < N; ++i) {
                totalDistance += Math.pow(upperBound[i], 2.0);
            }
            return totalDistance;
        }
        catch (Exception ex) {
            // NOTE(review): kept because it is unclear from this file whether
            // FailedToFitException preserves the cause's stack trace.
            ex.printStackTrace();
            throw new FailedToFitException(ex);
        }
    }

    /**
     * Computes the initial bounds and the initial closest centroid for every
     * point.
     *
     * <p>For each point the centroids are visited in index order. Each
     * evaluated distance becomes that centroid's lower bound. When a new
     * minimum is found it becomes the upper bound, and later centroids at
     * least twice that far from the new best centroid are skipped. The index
     * of the closest centroid is stored in {@code newDesignations}.
     *
     * @param k                     number of clusters
     * @param N                     number of points
     * @param lowerBound            {@code N x k} output of lower bounds
     * @param upperBound            length-{@code N} output of upper bounds
     * @param centroidSelfDistances centroid-to-centroid distances used for pruning
     * @param assignment            forwarded to {@code distance(q, i, assignment)}
     * @param parallel              forwarded to {@code ParallelUtils.run}
     */
    private void initialClusterSetUp(int k, int N, double[][] lowerBound, double[] upperBound, double[][] centroidSelfDistances, int[] assignment, boolean parallel) {
        ParallelUtils.run(parallel, N, (start, end) -> {
            // Reused across the points of this chunk, so reset it per point.
            boolean[] skip = new boolean[k];
            for (int q = start; q < end; ++q) {
                double minDistance = Double.MAX_VALUE;
                int index = -1;
                Arrays.fill(skip, false);
                for (int i = 0; i < k; ++i) {
                    if (skip[i]) {
                        continue;
                    }
                    double d = this.distance(q, i, assignment);
                    lowerBound[q][i] = d;
                    if (d < minDistance) {
                        upperBound[q] = d;
                        minDistance = d;
                        index = i;
                        for (int z = i + 1; z < k; ++z) {
                            if (centroidSelfDistances[i][z] >= 2.0 * d) {
                                skip[z] = true;
                            }
                        }
                    }
                }
                this.newDesignations[q] = index;
            }
        });
    }

    /**
     * Applies the pending reassignments to the means, measures how far each
     * mean moved, and adjusts the bounds to match.
     *
     * <p>Lower bounds are reduced by the movement of their centroid (clamped
     * at zero). Upper bounds are increased by the movement of the point's
     * newly assigned centroid. Every entry of {@code r} is set to
     * {@code true}, and {@code newDesignations} is copied into
     * {@code assignment}.
     *
     * @param k          number of clusters
     * @param N          number of points
     * @param lowerBound {@code N x k} lower bounds, updated in place
     * @param upperBound length-{@code N} upper bounds, updated in place
     * @param assignment current assignments, overwritten with the new ones
     * @param r          per-point flags, all set to {@code true}
     * @param parallel   forwarded to {@code ParallelUtils.run}
     * @return the sum of the values returned by {@code updateMeansFromChange}
     */
    private int step4_5_6_distanceMovedBoundsUpdate(int k, int N, double[][] lowerBound, double[] upperBound, int[] assignment, boolean[] r, boolean parallel) {
        double[] distancesMoved = new double[k];
        // Snapshot the normalised squared norms before the means are changed.
        double[] oldSqrdNorms = new double[this.meanSqrdNorms.length];
        for (int c = 0; c < oldSqrdNorms.length; ++c) {
            oldSqrdNorms[c] = this.meanSqrdNorms[c] * this.normConsts[c];
        }

        int moved = ParallelUtils.run(parallel, N, (start, end) -> {
            double[] sqrdChange = new double[k];
            double[] ownerChange = new double[k];
            int localChange = 0;
            for (int q = start; q < end; ++q) {
                localChange += this.updateMeansFromChange(q, assignment, sqrdChange, ownerChange);
            }
            // Chunk results are merged one at a time; 'assignment' is the lock.
            synchronized (assignment) {
                this.applyMeanUpdates(sqrdChange, ownerChange);
            }
            return localChange;
        }, (t, u) -> t + u);
        this.updateNormConsts();

        ParallelUtils.run(parallel, k, c -> {
            distancesMoved[c] = this.meanToMeanDistance(c, c, this.newDesignations, assignment, oldSqrdNorms[c], parallel);
        });

        ParallelUtils.run(parallel, k, c -> {
            for (int q = 0; q < N; ++q) {
                lowerBound[q][c] = Math.max(lowerBound[q][c] - distancesMoved[c], 0.0);
            }
        });

        System.arraycopy(this.newDesignations, 0, assignment, 0, N);

        ParallelUtils.run(parallel, N, (start, end) -> {
            for (int q = start; q < end; ++q) {
                upperBound[q] += distancesMoved[assignment[q]];
                r[q] = true;
            }
        });
        return moved;
    }

    /**
     * If point {@code q} is flagged in {@code r}, clears the flag and replaces
     * both its upper bound and the lower bound for its assigned centroid with
     * the freshly computed distance to that centroid.
     *
     * @param r          per-point flags
     * @param q          point index
     * @param assignment current assignments
     * @param upperBound upper bounds, updated in place
     * @param lowerBound lower bounds, updated in place
     */
    private void step3aBoundsUpdate(boolean[] r, int q, int[] assignment, double[] upperBound, double[][] lowerBound) {
        if (!r[q]) {
            return;
        }
        r[q] = false;
        int meanIndx = assignment[q];
        double d = this.distance(q, meanIndx, assignment);
        lowerBound[q][meanIndx] = d;
        upperBound[q] = d;
    }

    /**
     * Re-checks centroid {@code c} as a candidate for point {@code q}.
     *
     * <p>If the upper bound exceeds either the lower bound for {@code c} or
     * half the distance between the assigned centroid and {@code c}, the
     * distance to {@code c} is computed and stored as the lower bound. If it
     * is smaller than the upper bound, {@code c} is recorded in
     * {@code newDesignations}, the upper bound is tightened, and
     * {@code changeOccurred} is set.
     *
     * @param upperBound            upper bounds, updated in place
     * @param q                     point index
     * @param lowerBound            lower bounds, updated in place
     * @param c                     candidate centroid index
     * @param centroidSelfDistances centroid-to-centroid distances
     * @param assignment            current assignments
     * @param changeOccurred        set to {@code true} when {@code q} is reassigned
     */
    private void step3bUpdate(double[] upperBound, int q, double[][] lowerBound, int c, double[][] centroidSelfDistances, int[] assignment, AtomicBoolean changeOccurred) {
        boolean boundsInconclusive = upperBound[q] > lowerBound[q][c]
                || upperBound[q] > centroidSelfDistances[assignment[q]][c] / 2.0;
        if (!boundsInconclusive) {
            return;
        }
        double d = this.distance(q, c, assignment);
        lowerBound[q][c] = d;
        if (d < upperBound[q]) {
            this.newDesignations[q] = c;
            upperBound[q] = d;
            changeOccurred.lazySet(true);
        }
    }

    /**
     * Refreshes the centroid-to-centroid distances and, for each centroid,
     * half the distance to its nearest other centroid.
     *
     * <p>The pair dot products are first updated incrementally from
     * {@code prev_assignments} to {@code curAssignments}. Each pairwise
     * squared distance is then formed from the two normalised squared norms
     * minus twice the pair dot product divided by the product of the two
     * cluster weights. It is clamped at zero before the square root is taken.
     *
     * @param k                     number of clusters
     * @param centroidSelfDistances {@code k x k} output, written symmetrically;
     *                              the diagonal is left untouched
     * @param sC                    length-{@code k} output of half the minimum
     *                              distance to another centroid
     * @param curAssignments        current cluster index per point
     * @param prev_assignments      previous cluster index per point, or
     *                              {@code null} to treat every point as
     *                              previously unassigned
     * @param parallel              forwarded to {@code ParallelUtils.run}
     */
    private void calculateCentroidDistances(int k, double[][] centroidSelfDistances, double[] sC, int[] curAssignments, int[] prev_assignments, boolean parallel) {
        if (prev_assignments == null) {
            // -1 marks "no previous cluster", so nothing is subtracted.
            prev_assignments = new int[curAssignments.length];
            Arrays.fill(prev_assignments, -1);
        }
        this.update_centroid_pair_dots(prev_assignments, curAssignments, parallel);

        double[] weight_per_cluster = new double[k];
        for (int p = 0; p < curAssignments.length; ++p) {
            weight_per_cluster[curAssignments[p]] += this.W.get(p);
        }

        for (int i = 0; i < k; ++i) {
            for (int z = i + 1; z < k; ++z) {
                double dot = this.centroidPairDots[i][z] / (weight_per_cluster[i] * weight_per_cluster[z]);
                double sqrdDist = this.meanSqrdNorms[i] * this.normConsts[i] + this.meanSqrdNorms[z] * this.normConsts[z] - 2.0 * dot;
                // Clamp so rounding error cannot produce a negative argument.
                double dist = Math.sqrt(Math.max(0.0, sqrdDist));
                centroidSelfDistances[i][z] = dist;
                centroidSelfDistances[z][i] = dist;
            }
        }

        for (int i = 0; i < k; ++i) {
            double sCmin = Double.MAX_VALUE;
            for (int z = 0; z < k; ++z) {
                if (z != i) {
                    sCmin = Math.min(sCmin, centroidSelfDistances[i][z]);
                }
            }
            sC[i] = sCmin / 2.0;
        }
    }

    /**
     * Clusters {@code dataSet} into {@code clusters} groups.
     *
     * @param dataSet      the data to cluster
     * @param clusters     the number of clusters
     * @param parallel     forwarded to the internal clustering routine
     * @param designations array to receive the per-point cluster index, or
     *                     {@code null} to have one allocated
     * @return {@code designations}, or the newly allocated array
     * @throws ClusterFailureException if there are fewer points than clusters
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        if (designations == null) {
            designations = new int[dataSet.getSampleSize()];
        }
        if (dataSet.getSampleSize() < clusters) {
            throw new ClusterFailureException("Fewer data points then desired clusters, decrease cluster size");
        }
        this.cluster(dataSet, clusters, designations, false, parallel);
        return designations;
    }

    /**
     * Returns a copy made with the copy constructor.
     */
    @Override
    public ElkanKernelKMeans clone() {
        return new ElkanKernelKMeans(this);
    }
}

