/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.kmeans;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.DataSet;
import jsat.clustering.KClustererBase;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.ConstantVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * Shared state and helper routines for kernel k-means clusterers. All
 * distances are computed in the feature space induced by {@link #kernel},
 * using only kernel evaluations between data points; cluster means are never
 * materialised explicitly.
 * <p>
 * This class does not assign {@link #X} itself (other than in the copy
 * constructor); {@link #setup(int, int[], Vec)} and the distance helpers read
 * it, so a subclass has to set it before calling them.
 */
public abstract class KernelKMeans extends KClustererBase implements Parameterized {
    private static final long serialVersionUID = -5294680202634779440L;

    /** Kernel used for every similarity evaluation. */
    @Parameter.ParameterHolder
    protected KernelTrick kernel;
    /** Data points being clustered; indexed by the {@code i}/{@code j} arguments below. */
    protected List<Vec> X;
    /** Per-point weights; {@code setup} substitutes a constant vector of 1.0 when none is given. */
    protected Vec W;
    /** Acceleration cache obtained from {@code kernel.getAccelerationCache(X)}. */
    protected List<Double> accel;
    /** {@code selfK[i]} holds {@code kernel.eval(i, i, X, accel)}. */
    protected double[] selfK;
    /**
     * Per cluster, the weighted sum of kernel values over all pairs of member
     * points (not yet normalised; multiply by {@link #normConsts}).
     */
    protected double[] meanSqrdNorms;
    /** Per cluster, {@code 1 / ownes[k]^2}; refreshed by {@link #updateNormConsts()}. */
    protected double[] normConsts;
    /** Per cluster, the sum of the weights of the points assigned to it. */
    protected double[] ownes;
    /** Cluster index assigned to each point; used by the public query methods. */
    protected int[] newDesignations;
    /** Iteration limit; only stored and exposed here, presumably honoured by subclasses. */
    protected int maximumIterations = Integer.MAX_VALUE;

    /**
     * Creates a new instance using the given kernel.
     *
     * @param kernel the kernel to evaluate similarities with
     */
    public KernelKMeans(KernelTrick kernel) {
        this.kernel = kernel;
    }

    /**
     * Copy constructor. The kernel, the data points, the weights and every
     * array that is non-null in {@code toCopy} are duplicated, so the new
     * object shares no mutable state held directly by these fields.
     *
     * @param toCopy the object to copy
     */
    public KernelKMeans(KernelKMeans toCopy) {
        this.kernel = toCopy.kernel.clone();
        this.maximumIterations = toCopy.maximumIterations;
        if (toCopy.X != null) {
            this.X = new ArrayList<>(toCopy.X.size());
            for (Vec v : toCopy.X) {
                this.X.add(v.clone());
            }
        }
        if (toCopy.accel != null) {
            this.accel = new DoubleList(toCopy.accel);
        }
        if (toCopy.selfK != null) {
            this.selfK = Arrays.copyOf(toCopy.selfK, toCopy.selfK.length);
        }
        if (toCopy.meanSqrdNorms != null) {
            this.meanSqrdNorms = Arrays.copyOf(toCopy.meanSqrdNorms, toCopy.meanSqrdNorms.length);
        }
        if (toCopy.normConsts != null) {
            this.normConsts = Arrays.copyOf(toCopy.normConsts, toCopy.normConsts.length);
        }
        if (toCopy.ownes != null) {
            this.ownes = Arrays.copyOf(toCopy.ownes, toCopy.ownes.length);
        }
        if (toCopy.newDesignations != null) {
            this.newDesignations = Arrays.copyOf(toCopy.newDesignations, toCopy.newDesignations.length);
        }
        if (toCopy.W != null) {
            this.W = toCopy.W.clone();
        }
    }

    /**
     * Sets the iteration limit.
     *
     * @param iterLimit the new limit, must be positive
     * @throws IllegalArgumentException if {@code iterLimit} is zero or negative
     */
    public void setMaximumIterations(int iterLimit) {
        if (iterLimit <= 0) {
            throw new IllegalArgumentException("iterations must be a positive value, not " + iterLimit);
        }
        this.maximumIterations = iterLimit;
    }

    /**
     * @return the currently configured iteration limit
     */
    public int getMaximumIterations() {
        return this.maximumIterations;
    }

    /**
     * Not implemented in this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Not implemented in this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Not implemented in this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Not implemented in this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, int[] designations) {
        throw new UnsupportedOperationException("Not supported.");
    }

    /**
     * Computes the weighted sum of kernel values between data point {@code i}
     * and every data point assigned to {@code clusterID}.
     *
     * @param i index of the data point in {@link #X}
     * @param clusterID the cluster whose members are summed over
     * @param d the assignment of each data point to a cluster
     * @return the sum over all {@code j} with {@code d[j] == clusterID} of
     *         {@code W[j] * k(x_i, x_j)}
     */
    protected double evalSumK(int i, int clusterID, int[] d) {
        double sum = 0.0;
        for (int j = 0; j < this.X.size(); ++j) {
            if (d[j] == clusterID) {
                sum += this.W.get(j) * this.kernel.eval(i, j, this.X, this.accel);
            }
        }
        return sum;
    }

    /**
     * Computes the weighted sum of kernel values between an arbitrary vector
     * and every data point assigned to {@code clusterID}.
     *
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @param clusterID the cluster whose members are summed over
     * @param d the assignment of each data point to a cluster
     * @return the sum over all {@code j} with {@code d[j] == clusterID} of
     *         {@code W[j] * k(x_j, x)}
     */
    protected double evalSumK(Vec x, List<Double> qi, int clusterID, int[] d) {
        double sum = 0.0;
        for (int j = 0; j < this.X.size(); ++j) {
            if (d[j] == clusterID) {
                sum += this.W.get(j) * this.kernel.eval(j, x, qi, this.X, this.accel);
            }
        }
        return sum;
    }

    /**
     * Initialises all bookkeeping for a clustering run: builds the kernel
     * acceleration cache and the self-kernel values, assigns every point to a
     * uniformly random cluster, and computes the per-cluster weight totals,
     * normalisation constants and un-normalised squared norms.
     * {@link #X} must already be set.
     *
     * @param K the number of clusters
     * @param designations array that receives the random initial assignment;
     *        it is overwritten for every data point
     * @param W the weight of each data point, or {@code null} to weight every
     *        point by 1.0
     */
    protected void setup(int K, int[] designations, Vec W) {
        this.accel = this.kernel.getAccelerationCache(this.X);
        final int N = this.X.size();

        this.selfK = new double[N];
        for (int i = 0; i < this.selfK.length; ++i) {
            this.selfK[i] = this.kernel.eval(i, i, this.X, this.accel);
        }

        this.ownes = new double[K];
        this.meanSqrdNorms = new double[K];
        this.newDesignations = new int[N];
        this.W = W == null ? new ConstantVector(1.0, N) : W;

        // Random initial assignment; accumulate each cluster's total weight.
        Random rand = RandomUtil.getRandom();
        for (int i = 0; i < N; ++i) {
            int to = rand.nextInt(K);
            this.ownes[to] += this.W.get(i);
            designations[i] = to;
            this.newDesignations[i] = to;
        }

        this.normConsts = new double[K];
        this.updateNormConsts();

        // Each unordered pair (i, j) is visited once, so the off-diagonal
        // kernel terms are counted with a factor of two.
        for (int i = 0; i < N; ++i) {
            final int i_k = designations[i];
            final double w_i = this.W.get(i);
            this.meanSqrdNorms[i_k] += w_i * this.selfK[i];
            for (int j = i + 1; j < N; ++j) {
                if (i_k == designations[j]) {
                    this.meanSqrdNorms[i_k] += 2.0 * w_i * this.W.get(j) * this.kernel.eval(i, j, this.X, this.accel);
                }
            }
        }
    }

    /**
     * Recomputes {@link #normConsts} as {@code 1 / ownes[k]^2} for every
     * cluster. A cluster with zero total weight yields an infinite constant.
     */
    protected void updateNormConsts() {
        for (int k = 0; k < this.normConsts.length; ++k) {
            this.normConsts[k] = 1.0 / (this.ownes[k] * this.ownes[k]);
        }
    }

    /**
     * Computes the feature-space distance between data point {@code i} and
     * the mean of cluster {@code k}. Negative squared distances are clamped
     * to zero before the square root is taken.
     *
     * @param i index of the data point in {@link #X}
     * @param k index of the cluster
     * @param designations the assignment of each data point to a cluster
     * @return the distance between point {@code i} and the mean of cluster {@code k}
     */
    protected double distance(int i, int k, int[] designations) {
        double sqrd = this.selfK[i]
                - 2.0 / this.ownes[k] * this.evalSumK(i, k, designations)
                + this.meanSqrdNorms[k] * this.normConsts[k];
        return Math.sqrt(Math.max(sqrd, 0.0));
    }

    /**
     * Computes the feature-space distance between {@code x} and the mean of
     * cluster {@code k}, obtaining the kernel query information for
     * {@code x} first.
     *
     * @param x the query vector
     * @param k index of the cluster
     * @return the distance between {@code x} and the mean of cluster {@code k}
     * @throws IndexOutOfBoundsException if {@code k} is not a valid cluster index
     */
    public double distance(Vec x, int k) {
        return this.distance(x, this.kernel.getQueryInfo(x), k);
    }

    /**
     * Computes the feature-space distance between {@code x} and the mean of
     * cluster {@code k}, using the assignments stored in
     * {@link #newDesignations}. Negative squared distances are clamped to
     * zero before the square root is taken.
     *
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @param k index of the cluster
     * @return the distance between {@code x} and the mean of cluster {@code k}
     * @throws IndexOutOfBoundsException if {@code k} is not a valid cluster index
     */
    public double distance(Vec x, List<Double> qi, int k) {
        if (k >= this.meanSqrdNorms.length || k < 0) {
            throw new IndexOutOfBoundsException("Only " + this.meanSqrdNorms.length + " clusters. " + k + " is not a valid index");
        }
        double sqrd = this.kernel.eval(0, 0, Arrays.asList(x), qi)
                - 2.0 / this.ownes[k] * this.evalSumK(x, qi, k, this.newDesignations)
                + this.meanSqrdNorms[k] * this.normConsts[k];
        return Math.sqrt(Math.max(sqrd, 0.0));
    }

    /**
     * Finds the cluster whose mean is closest to {@code x}, obtaining the
     * kernel query information for {@code x} first.
     *
     * @param x the query vector
     * @return the index of the closest cluster, or -1 if no cluster produced
     *         a distance below {@link Double#MAX_VALUE}
     */
    public int findClosestCluster(Vec x) {
        return this.findClosestCluster(x, this.kernel.getQueryInfo(x));
    }

    /**
     * Finds the cluster whose mean is closest to {@code x}. Ties are resolved
     * in favour of the lowest cluster index.
     *
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @return the index of the closest cluster, or -1 if no cluster produced
     *         a distance below {@link Double#MAX_VALUE} (for example when
     *         every distance is NaN)
     */
    public int findClosestCluster(Vec x, List<Double> qi) {
        double min = Double.MAX_VALUE;
        int min_indx = -1;
        for (int k = 0; k < this.meanSqrdNorms.length; ++k) {
            double dist = this.distance(x, qi, k);
            if (dist < min) {
                min = dist;
                min_indx = k;
            }
        }
        return min_indx;
    }

    /**
     * Applies the effect of moving data point {@code i} from its cluster in
     * {@code designations} to its cluster in {@link #newDesignations}
     * directly to {@link #meanSqrdNorms} and {@link #ownes}.
     *
     * @param i index of the data point to process
     * @param designations the previous assignment of each data point
     * @return 1 if the point changed cluster, 0 otherwise
     */
    protected int updateMeansFromChange(int i, int[] designations) {
        return this.updateMeansFromChange(i, designations, this.meanSqrdNorms, this.ownes);
    }

    /**
     * Accumulates into the given arrays the effect of moving data point
     * {@code i} from its cluster in {@code designations} to its cluster in
     * {@link #newDesignations}. Nothing is written when the point did not
     * change cluster.
     *
     * @param i index of the data point to process
     * @param designations the previous assignment of each data point
     * @param sqrdNorms per-cluster array that receives the change to the
     *        un-normalised squared norm
     * @param ownership per-cluster array that receives the change to the
     *        total weight
     * @return 1 if the point changed cluster, 0 otherwise
     */
    protected int updateMeansFromChange(int i, int[] designations, double[] sqrdNorms, double[] ownership) {
        final int old_d = designations[i];
        final int new_d = this.newDesignations[i];
        if (old_d == new_d) {
            return 0;
        }

        final int N = this.X.size();
        final double w_i = this.W.get(i);
        ownership[old_d] -= w_i;
        ownership[new_d] += w_i;

        for (int j = 0; j < N; ++j) {
            final double w_j = this.W.get(j);
            final int oldD_j = designations[j];
            final int newD_j = this.newDesignations[j];

            if (i == j) {
                // Diagonal term simply moves from the old cluster to the new one.
                sqrdNorms[old_d] -= w_i * this.selfK[i];
                sqrdNorms[new_d] += w_i * this.selfK[i];
                continue;
            }

            // When point j also changed cluster, the pair (i, j) is handled
            // only from the lower index so that it is not counted twice once
            // every changed point has been processed.
            final boolean pairOwnedHere = i <= j || oldD_j == newD_j;

            if (old_d == oldD_j && pairOwnedHere) {
                sqrdNorms[old_d] -= 2.0 * w_i * w_j * this.kernel.eval(i, j, this.X, this.accel);
            }
            if (new_d == newD_j && pairOwnedHere) {
                sqrdNorms[new_d] += 2.0 * w_i * w_j * this.kernel.eval(i, j, this.X, this.accel);
            }
        }
        return 1;
    }

    /**
     * Adds the given per-cluster deltas to {@link #meanSqrdNorms} and
     * {@link #ownes}. {@link #normConsts} is not refreshed here.
     *
     * @param sqrdNorms per-cluster change to the un-normalised squared norm
     * @param ownerships per-cluster change to the total weight
     */
    protected void applyMeanUpdates(double[] sqrdNorms, double[] ownerships) {
        for (int k = 0; k < sqrdNorms.length; ++k) {
            this.meanSqrdNorms[k] += sqrdNorms[k];
            this.ownes[k] += ownerships[k];
        }
    }

    /**
     * Computes the feature-space distance between the means of two clusters
     * using the assignments stored in {@link #newDesignations}.
     *
     * @param k0 index of the first cluster
     * @param k1 index of the second cluster
     * @return the distance between the two cluster means
     * @throws IndexOutOfBoundsException if either index is not a valid cluster index
     */
    public double meanToMeanDistance(int k0, int k1) {
        if (k0 >= this.meanSqrdNorms.length || k0 < 0) {
            throw new IndexOutOfBoundsException("Only " + this.meanSqrdNorms.length + " clusters. " + k0 + " is not a valid index");
        }
        if (k1 >= this.meanSqrdNorms.length || k1 < 0) {
            throw new IndexOutOfBoundsException("Only " + this.meanSqrdNorms.length + " clusters. " + k1 + " is not a valid index");
        }
        return this.meanToMeanDistance(k0, k1, this.newDesignations);
    }

    /**
     * Computes the feature-space distance between the means of two clusters.
     * Negative squared distances are clamped to zero.
     *
     * @param k0 index of the first cluster
     * @param k1 index of the second cluster
     * @param assignments the assignment of each data point to a cluster
     * @return the distance between the two cluster means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments) {
        double d = this.meanSqrdNorms[k0] * this.normConsts[k0]
                + this.meanSqrdNorms[k1] * this.normConsts[k1]
                - 2.0 * this.dot(k0, k1, assignments);
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Computes the feature-space distance between the means of two clusters,
     * optionally evaluating the cross term in parallel. Negative squared
     * distances are clamped to zero.
     *
     * @param k0 index of the first cluster
     * @param k1 index of the second cluster
     * @param assignments the assignment of each data point to a cluster
     * @param parallel forwarded to {@code ParallelUtils.run}
     * @return the distance between the two cluster means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments, boolean parallel) {
        double d = this.meanSqrdNorms[k0] * this.normConsts[k0]
                + this.meanSqrdNorms[k1] * this.normConsts[k1]
                - 2.0 * this.dot(k0, k1, assignments, parallel);
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Computes the feature-space distance between the mean of cluster
     * {@code k0} under {@code assignments0} and the mean of cluster
     * {@code k1} under {@code assignments1}. Negative squared distances are
     * clamped to zero.
     *
     * @param k0 index of the first cluster
     * @param k1 index of the second cluster
     * @param assignments0 the assignment that defines cluster {@code k0}
     * @param assignments1 the assignment that defines cluster {@code k1}
     * @param k1SqrdNorm squared norm of the second mean; it is added as
     *        given, so it is presumably expected to be already normalised
     * @return the distance between the two cluster means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments0, int[] assignments1, double k1SqrdNorm) {
        double d = this.meanSqrdNorms[k0] * this.normConsts[k0]
                + k1SqrdNorm
                - 2.0 * this.dot(k0, k1, assignments0, assignments1);
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Computes the feature-space distance between the mean of cluster
     * {@code k0} under {@code assignments0} and the mean of cluster
     * {@code k1} under {@code assignments1}, optionally evaluating the cross
     * term in parallel. Negative squared distances are clamped to zero.
     *
     * @param k0 index of the first cluster
     * @param k1 index of the second cluster
     * @param assignments0 the assignment that defines cluster {@code k0}
     * @param assignments1 the assignment that defines cluster {@code k1}
     * @param k1SqrdNorm squared norm of the second mean; it is added as
     *        given, so it is presumably expected to be already normalised
     * @param parallel forwarded to {@code ParallelUtils.run}
     * @return the distance between the two cluster means
     */
    protected double meanToMeanDistance(int k0, int k1, int[] assignments0, int[] assignments1, double k1SqrdNorm, boolean parallel) {
        double d = this.meanSqrdNorms[k0] * this.normConsts[k0]
                + k1SqrdNorm
                - 2.0 * this.dot(k0, k1, assignments0, assignments1, parallel);
        return Math.sqrt(Math.max(0.0, d));
    }

    /**
     * Dot product between the means of clusters {@code k0} and {@code k1},
     * both defined by the same assignment.
     */
    private double dot(int k0, int k1, int[] assignment) {
        return this.dot(k0, k1, assignment, assignment);
    }

    /**
     * Dot product between the means of clusters {@code k0} and {@code k1},
     * both defined by the same assignment, optionally computed in parallel.
     */
    private double dot(int k0, int k1, int[] assignment, boolean parallel) {
        return this.dot(k0, k1, assignment, assignment, parallel);
    }

    /**
     * Dot product between the mean of cluster {@code k0} (members given by
     * {@code assignment0}) and the mean of cluster {@code k1} (members given
     * by {@code assignment1}): the weighted sum of all cross kernel values
     * divided by the product of the two clusters' total weights.
     */
    private double dot(int k0, int k1, int[] assignment0, int[] assignment1) {
        final int N = this.X.size();
        double weightedSum = 0.0;
        double a = 0.0; // total weight of cluster k0
        double b = 0.0; // total weight of cluster k1

        for (int i = 0; i < N; ++i) {
            final double w_i = this.W.get(i);
            if (assignment0[i] != k0) {
                continue;
            }
            a += w_i;
            for (int j = 0; j < N; ++j) {
                if (assignment1[j] == k1) {
                    weightedSum += w_i * this.W.get(j) * this.kernel.eval(i, j, this.X, this.accel);
                }
            }
        }
        for (int j = 0; j < N; ++j) {
            if (assignment1[j] == k1) {
                b += this.W.get(j);
            }
        }
        return weightedSum / (a * b);
    }

    /**
     * Same quantity as {@link #dot(int, int, int[], int[])}, with the outer
     * loop over data points handed to {@code ParallelUtils.run}.
     */
    private double dot(int k0, int k1, int[] assignment0, int[] assignment1, boolean parallel) {
        final int N = this.X.size();

        // Each index contributes its row of the cross term (or 0.0 when the
        // point is not in cluster k0); the combiner sums the contributions and
        // the reduced value is the numerator of the result.
        final double weightedSum = ParallelUtils.run(parallel, N, i -> {
            final double w_i = this.W.get(i);
            if (assignment0[i] != k0) {
                return 0.0;
            }
            double localDot = 0.0;
            for (int j = 0; j < N; ++j) {
                if (assignment1[j] != k1) {
                    continue;
                }
                final double w_j = this.W.get(j);
                localDot += w_i * w_j * this.kernel.eval(i, j, this.X, this.accel);
            }
            return localDot;
        }, (t, u) -> t + u);

        // Normalise by the total weight of each cluster only, matching the
        // serial overload, rather than by the weight of the whole data set.
        double a = 0.0;
        double b = 0.0;
        for (int idx = 0; idx < N; ++idx) {
            final double w = this.W.get(idx);
            if (assignment0[idx] == k0) {
                a += w;
            }
            if (assignment1[idx] == k1) {
                b += w;
            }
        }
        return weightedSum / (a * b);
    }

    /**
     * @return a copy of this object; concrete subclasses supply the implementation
     */
    @Override
    public abstract KernelKMeans clone();

    /**
     * @return {@code true}; the helpers in this class scale every kernel
     *         term by the per-point weights in {@link #W}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }
}

