/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering;

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.concurrent.atomic.AtomicIntegerArray;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
import jsat.DataSet;
import jsat.clustering.PAM;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.math.OnLineStatistics;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDoubleArray;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * A PAM-style k-medoids clusterer whose medoid-update step uses a randomized,
 * bandit-style estimate instead of exhaustively computing every candidate's
 * average distance.
 * <p>
 * For each cluster, every member's average distance to the other members is
 * estimated from random distance samples. A confidence interval of half-width
 * {@code sqrt(2 * variance * log(1/tolerance) / samplesTaken)} is kept around
 * each estimate. Candidates with the smallest lower bound are sampled further.
 * A candidate is computed exactly once it has been sampled about as many times
 * as there are other members. The search stops when an exactly-known candidate
 * is better than every remaining lower bound, or when all candidates are exact.
 */
public class MEDDIT extends PAM {
    /**
     * Confidence tolerance for medoid estimation.
     * <ul>
     *   <li>A negative value means 1/n is used, where n is the dataset size.</li>
     *   <li>Zero forces the exact PAM medoid computation.</li>
     * </ul>
     */
    private double tolerance = 0.01;

    /**
     * Creates a new MEDDIT clusterer. The arguments are forwarded to the
     * matching PAM constructor.
     *
     * @param dm the distance metric to cluster with
     * @param rand the source of randomness for seed selection
     * @param seedSelection the method used to pick the initial medoids
     */
    public MEDDIT(DistanceMetric dm, Random rand, SeedSelectionMethods.SeedSelection seedSelection) {
        super(dm, rand, seedSelection);
    }

    /**
     * Creates a new MEDDIT clusterer. The arguments are forwarded to the
     * matching PAM constructor.
     *
     * @param dm the distance metric to cluster with
     * @param rand the source of randomness for seed selection
     */
    public MEDDIT(DistanceMetric dm, Random rand) {
        super(dm, rand);
    }

    /**
     * Creates a new MEDDIT clusterer. The metric is forwarded to the matching
     * PAM constructor.
     *
     * @param dm the distance metric to cluster with
     */
    public MEDDIT(DistanceMetric dm) {
        super(dm);
    }

    /**
     * Creates a new MEDDIT clusterer using PAM's default configuration.
     */
    public MEDDIT() {
    }

    /**
     * Sets the tolerance used when estimating medoids. Smaller values widen the
     * confidence intervals, demanding more certainty (and more distance
     * computations) before a medoid is accepted.
     *
     * @param tolerance the tolerance. Negative means "use 1/n" for a dataset of
     *                  size n; zero forces exact medoid computation.
     */
    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }

    /**
     * Returns the tolerance used when estimating medoids.
     *
     * @return the current tolerance value
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Runs the k-medoids loop. The loop alternates two steps until no
     * assignment changes or {@code iterLimit} is exceeded:
     * <ul>
     *   <li>assign each point to its nearest medoid;</li>
     *   <li>re-estimate each non-empty cluster's medoid with
     *       {@link #medoid(boolean, Collection, double, List, DistanceMetric, List)}.</li>
     * </ul>
     *
     * @return the sum of squared point-to-medoid distances measured during the
     *         final assignment pass
     */
    @Override
    protected double cluster(DataSet data, boolean doInit, int[] medioids, int[] assignments, List<Double> cacheAccel, boolean parallel) {
        DoubleAdder totalDistance = new DoubleAdder();
        LongAdder changes = new LongAdder();
        Arrays.fill(assignments, -1);
        List<Vec> X = data.getDataVectors();
        int N = data.getSampleSize();
        int K = medioids.length;

        List<Double> accel;
        if (doInit) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, data);
            accel = this.dm.getAccelerationCache(X);
            SeedSelectionMethods.selectIntialPoints(data, medioids, this.dm, accel, this.rand, this.seedSelection);
        } else {
            accel = cacheAccel;
        }

        // A negative tolerance means "derive one from the data size".
        double tol = this.tolerance < 0.0 ? 1.0 / N : this.tolerance;

        // Per-cluster member lists, reused across iterations.
        IntList[] ownedBy = new IntList[K];
        for (int k = 0; k < K; ++k) {
            ownedBy[k] = new IntList();
        }

        int iter = 0;
        do {
            changes.reset();
            totalDistance.reset();

            // Assignment step: nearest medoid wins; ties keep the lowest k.
            ParallelUtils.run(parallel, N, (start, end) -> {
                for (int i = start; i < end; ++i) {
                    int assignment = 0;
                    double minDist = this.dm.dist(medioids[0], i, X, accel);
                    for (int k = 1; k < medioids.length; ++k) {
                        double dist = this.dm.dist(medioids[k], i, X, accel);
                        if (dist < minDist) {
                            minDist = dist;
                            assignment = k;
                        }
                    }
                    if (assignments[i] != assignment) {
                        changes.increment();
                        assignments[i] = assignment;
                    }
                    totalDistance.add(minDist * minDist);
                }
            });

            // Update step: bucket points by owner in a single pass.
            // Points are added in ascending index order.
            for (int k = 0; k < K; ++k) {
                ownedBy[k].clear();
            }
            for (int i = 0; i < N; ++i) {
                ownedBy[assignments[i]].add(i);
            }
            for (int k = 0; k < K; ++k) {
                if (!ownedBy[k].isEmpty()) {
                    medioids[k] = MEDDIT.medoid(parallel, ownedBy[k], tol, X, this.dm, accel);
                }
            }
        } while (changes.sum() > 0L && iter++ < this.iterLimit);
        return totalDistance.sum();
    }

    /**
     * Estimates the medoid of all of {@code X} using tolerance 1/|X|.
     *
     * @param parallel whether to use multiple threads
     * @param X the points to find the medoid of
     * @param dm the distance metric
     * @return the index into {@code X} of the (estimated) medoid
     */
    public static int medoid(boolean parallel, List<? extends Vec> X, DistanceMetric dm) {
        return MEDDIT.medoid(parallel, X, 1.0 / X.size(), dm);
    }

    /**
     * Estimates the medoid of all of {@code X}. This builds the metric's
     * acceleration cache for {@code X}, then considers every index of
     * {@code X} as a candidate.
     *
     * @param parallel whether to use multiple threads
     * @param X the points to find the medoid of
     * @param tol the confidence tolerance; values {@code <= 0} force an exact computation
     * @param dm the distance metric
     * @return the index into {@code X} of the (estimated) medoid
     */
    public static int medoid(boolean parallel, List<? extends Vec> X, double tol, DistanceMetric dm) {
        IntList order = new IntList(X.size());
        ListUtils.addRange(order, 0, X.size(), 1);
        List<Double> accel = dm.getAccelerationCache(X, parallel);
        return MEDDIT.medoid(parallel, order, tol, X, dm, accel);
    }

    /**
     * Estimates the medoid of the given subset of {@code X} using tolerance
     * 1/|indecies|.
     *
     * @param parallel whether to use multiple threads
     * @param indecies the indices into {@code X} of the candidate points
     * @param X the full point list
     * @param dm the distance metric
     * @param accel the metric's acceleration cache for {@code X}
     * @return the index into {@code X} of the (estimated) medoid
     */
    public static int medoid(boolean parallel, Collection<Integer> indecies, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        return MEDDIT.medoid(parallel, indecies, 1.0 / indecies.size(), X, dm, accel);
    }

    /**
     * Estimates the medoid of the given subset of {@code X} with the
     * confidence-bound search described on the class.
     * <p>
     * Falls back to the exact {@code PAM.medoid} in three cases:
     * <ul>
     *   <li>{@code tol <= 0};</li>
     *   <li>there are fewer than two candidates;</li>
     *   <li>there are fewer candidates than logical cores.</li>
     * </ul>
     *
     * @param parallel whether to use multiple threads
     * @param indecies the indices into {@code X} of the candidate points
     * @param tol the confidence tolerance; values {@code <= 0} force an exact computation
     * @param X the full point list
     * @param dm the distance metric
     * @param accel the metric's acceleration cache for {@code X}
     * @return the index into {@code X} (an element of {@code indecies}) of the (estimated) medoid
     */
    public static int medoid(boolean parallel, Collection<Integer> indecies, double tol, List<? extends Vec> X, DistanceMetric dm, List<Double> accel) {
        final int N = indecies.size();
        // With fewer than two candidates there is no "other" point to sample,
        // and the rejection-sampling loops below would spin forever.
        if (tol <= 0.0 || N < 2 || N < SystemInfo.LogicalCores) {
            return PAM.medoid(parallel, indecies, X, dm, accel);
        }

        // log(1/tol) scales the confidence radius.
        // Clamped so that tol >= 1 cannot produce sqrt(negative) = NaN.
        final double logInvTol = Math.max(0.0, -Math.log(tol));
        AtomicDoubleArray totalDistSum = new AtomicDoubleArray(N);
        AtomicIntegerArray totalDistCount = new AtomicIntegerArray(N);
        final int[] indx_map = indecies.stream().mapToInt(Integer::intValue).toArray();
        final boolean symmetric = dm.isSymmetric();
        double[] lower_bound_est = new double[N];
        double[] upper_bound_est = new double[N];
        ThreadLocal<Random> localRand = ThreadLocal.withInitial(RandomUtil::getRandom);

        // Phase 1: give every candidate one random distance sample.
        // For symmetric metrics, the sample also counts for the other endpoint.
        OnLineStatistics distanceStats = ParallelUtils.run(parallel, N, (start, end) -> {
            Random rand = localRand.get();
            OnLineStatistics localStats = new OnLineStatistics();
            for (int i = start; i < end; ++i) {
                int j = rand.nextInt(N);
                while (j == i) {
                    j = rand.nextInt(N);
                }
                double d_ij = dm.dist(indx_map[i], indx_map[j], X, accel);
                localStats.add(d_ij);
                totalDistSum.addAndGet(i, d_ij);
                totalDistCount.incrementAndGet(i);
                if (symmetric) {
                    totalDistSum.addAndGet(j, d_ij);
                    totalDistCount.incrementAndGet(j);
                }
            }
            return localStats;
        }, (a, b) -> OnLineStatistics.add(a, b));

        // Candidates ordered by their current bounds, with ties broken by index.
        // A candidate's bounds must not change while it is inside a set, so it
        // is always removed before its bounds are recomputed.
        ConcurrentSkipListSet<Integer> lowerQ = new ConcurrentSkipListSet<Integer>((o1, o2) -> {
            int cmp = Double.compare(lower_bound_est[o1], lower_bound_est[o2]);
            return cmp != 0 ? cmp : Integer.compare(o1, o2);
        });
        ConcurrentSkipListSet<Integer> upperQ = new ConcurrentSkipListSet<Integer>((o1, o2) -> {
            int cmp = Double.compare(upper_bound_est[o1], upper_bound_est[o2]);
            return cmp != 0 ? cmp : Integer.compare(o1, o2);
        });

        // Phase 2: initial confidence intervals.
        // Every candidate has T_i >= 1 because it sampled once for itself.
        final double initialVariance = distanceStats.getVarance();
        ParallelUtils.run(parallel, N, (start, end) -> {
            for (int i = start; i < end; ++i) {
                int T_i = totalDistCount.get(i);
                double mean = totalDistSum.get(i) / T_i;
                double c_i = Math.sqrt(2.0 * initialVariance * logInvTol / T_i);
                lower_bound_est[i] = mean - c_i;
                upper_bound_est[i] = mean + c_i;
                lowerQ.add(i);
                upperQ.add(i);
            }
        });

        final int num_to_pull = parallel ? Math.max(SystemInfo.LogicalCores, 32) : Math.min(32, N);
        final int samples = Math.min(32, N - 1);
        IntList to_pull = new IntList();
        IntList toAddBack = new IntList();
        boolean[] isExact = new boolean[N];
        int numExact = 0;

        // Phase 3: refine the most promising (smallest lower bound) candidates.
        while (numExact < N) {
            to_pull.clear();
            toAddBack.clear();

            // Exact candidates leave lowerQ once polled again. So this can only
            // fire when the best upper bound belongs to an exact candidate that
            // beats every lower bound still in play.
            int bestUpper = upperQ.first();
            if (upper_bound_est[bestUpper] < lower_bound_est[lowerQ.first()]) {
                return indx_map[bestUpper];
            }

            while (to_pull.size() < num_to_pull && !lowerQ.isEmpty()) {
                int cand = lowerQ.pollFirst();
                if (isExact[cand]) {
                    continue; // already known exactly; dropped from lowerQ
                }
                if (totalDistCount.get(cand) >= N - 1) {
                    // Sampled about as much as an exact pass would cost: compute it exactly.
                    double exactSum = ParallelUtils.run(parallel, N, (start, end) -> {
                        double d = 0.0;
                        for (int j = start; j < end; ++j) {
                            if (j != cand) {
                                d += dm.dist(indx_map[cand], indx_map[j], X, accel);
                            }
                        }
                        return d;
                    }, (a, b) -> a + b);
                    double exactMean = exactSum / (N - 1);
                    upperQ.remove(cand); // remove before its key changes
                    lower_bound_est[cand] = exactMean;
                    upper_bound_est[cand] = exactMean;
                    // Keep the stored totals consistent: sum / count == exact mean.
                    totalDistSum.set(cand, exactSum);
                    totalDistCount.set(cand, N - 1);
                    isExact[cand] = true;
                    ++numExact;
                    toAddBack.add(cand);
                    continue;
                }
                to_pull.add(cand);
            }

            if (!to_pull.isEmpty()) {
                OnLineStatistics newSamples = ParallelUtils.run(parallel, to_pull.size(), (start, end) -> {
                    Random rand = localRand.get();
                    OnLineStatistics localStats = new OnLineStatistics();
                    for (int p = start; p < end; ++p) {
                        int i = to_pull.get(p);
                        for (int s = 0; s < samples; ++s) {
                            int j = rand.nextInt(N);
                            while (j == i) {
                                j = rand.nextInt(N);
                            }
                            double d_ij = dm.dist(indx_map[i], indx_map[j], X, accel);
                            localStats.add(d_ij);
                            totalDistSum.addAndGet(i, d_ij);
                            totalDistCount.incrementAndGet(i);
                            // Exact candidates are frozen; don't perturb their totals.
                            if (symmetric && !isExact[j]) {
                                totalDistSum.addAndGet(j, d_ij);
                                totalDistCount.incrementAndGet(j);
                            }
                        }
                    }
                    return localStats;
                }, (a, b) -> OnLineStatistics.add(a, b));
                distanceStats.add(newSamples);
            }

            double v = distanceStats.getVarance();
            lowerQ.addAll(toAddBack);
            upperQ.addAll(toAddBack);
            // Pulled candidates are already out of lowerQ (polled).
            // Take them out of upperQ before their keys change.
            upperQ.removeAll(to_pull);
            for (int i : to_pull) {
                int T_i = totalDistCount.get(i);
                double mean = totalDistSum.get(i) / T_i;
                double c_i = Math.sqrt(2.0 * v * logInvTol / T_i);
                lower_bound_est[i] = mean - c_i;
                upper_bound_est[i] = mean + c_i;
                lowerQ.add(i);
                upperQ.add(i);
            }
        }

        // Every candidate is exact, so lower == upper == true mean distance.
        int bestIndex = 0;
        for (int i = 1; i < N; ++i) {
            if (lower_bound_est[i] < lower_bound_est[bestIndex]) {
                bestIndex = i;
            }
        }
        // Map back from candidate position to an index into X, as every other return does.
        return indx_map[bestIndex];
    }
}

