/*
 * Decompiled with CFR 0.152.
 */
package jsat.clustering.hierarchical;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.clustering.ClustererBase;
import jsat.clustering.KClusterer;
import jsat.clustering.dissimilarity.LanceWilliamsDissimilarity;
import jsat.clustering.dissimilarity.WardsDissimilarity;
import jsat.clustering.hierarchical.PriorityHAC;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.IndexTable;
import jsat.utils.IntDoubleMap;
import jsat.utils.IntDoubleMapArray;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Hierarchical agglomerative clusterer that builds its merge sequence with a
 * nearest-neighbor chain.
 * <p>
 * A chain of clusters is grown by repeatedly stepping to the nearest neighbor
 * of the current tip. As soon as the tip equals the element two positions
 * back (the chain ends in {@code x, y, x}) the last two clusters are merged.
 * Dissimilarities from the merged cluster to every remaining cluster are
 * obtained from the configured {@link LanceWilliamsDissimilarity}; distances
 * between two single points come from the configured {@link DistanceMetric}.
 * <p>
 * The merge history of the most recent run is retained so that a different
 * number of clusters can be extracted afterwards through
 * {@link #getClusterDesignations(int[], int)}.
 */
public class NNChainHAC
implements KClusterer {
    /** Update rule used to compute the dissimilarity to a freshly merged cluster. */
    private final LanceWilliamsDissimilarity distMeasure;
    /** Metric used for the distance between two clusters that both hold a single point. */
    private final DistanceMetric dm;
    /**
     * Merge history of the last clustering run as (kept, removed) index pairs,
     * or {@code null} if no run has happened yet. Pairs are written
     * back-to-front: entry 0 of the distance-ordered merge list occupies the
     * last two slots.
     */
    private int[] merges;

    /**
     * Creates a clusterer that uses {@link WardsDissimilarity} and the
     * {@link EuclideanDistance}.
     */
    public NNChainHAC() {
        this(new WardsDissimilarity());
    }

    /**
     * Creates a clusterer that uses the given update rule and the
     * {@link EuclideanDistance}.
     *
     * @param distMeasure the Lance-Williams update rule to use
     */
    public NNChainHAC(LanceWilliamsDissimilarity distMeasure) {
        this(distMeasure, new EuclideanDistance());
    }

    /**
     * Creates a clusterer that uses the given update rule and point metric.
     *
     * @param distMeasure the Lance-Williams update rule to use
     * @param distance the metric used between individual data points
     */
    public NNChainHAC(LanceWilliamsDissimilarity distMeasure, DistanceMetric distance) {
        this.distMeasure = distMeasure;
        this.dm = distance;
    }

    /**
     * Copy constructor. The update rule and metric are cloned and the merge
     * history, if any, is copied.
     *
     * @param toCopy the instance to copy
     */
    protected NNChainHAC(NNChainHAC toCopy) {
        this.distMeasure = toCopy.distMeasure.clone();
        this.dm = toCopy.dm.clone();
        if (toCopy.merges != null) {
            this.merges = Arrays.copyOf(toCopy.merges, toCopy.merges.length);
        }
    }

    @Override
    public NNChainHAC clone() {
        return new NNChainHAC(this);
    }

    /**
     * Clusters the data set, letting the number of clusters range from 2 to
     * the (truncated) square root of the sample size.
     *
     * @param dataSet the data to cluster
     * @param parallel whether work may be spread over multiple threads
     * @param designations array to store the assignments in, or {@code null}
     *        to have one allocated
     * @return the array holding the cluster assignments
     */
    @Override
    public int[] cluster(DataSet dataSet, boolean parallel, int[] designations) {
        return this.cluster(dataSet, 2, (int)Math.sqrt(dataSet.getSampleSize()), parallel, designations);
    }

    /**
     * Returns the current dissimilarity between clusters {@code a} and
     * {@code j}.
     * <p>
     * When both clusters still hold a single point the value is computed
     * directly with the distance metric. Otherwise it is looked up in the
     * per-cluster rows: row {@code a} is consulted first and row {@code j} is
     * the fallback when row {@code a} is absent or has no entry for {@code j}.
     *
     * @param a index of the first cluster
     * @param j index of the second cluster
     * @param size current number of points per cluster index
     * @param vecs the data points
     * @param cache acceleration cache belonging to {@code vecs}
     * @param d_xk per-cluster rows of stored dissimilarities; an entry is
     *        {@code null} when that cluster has no row
     * @return the dissimilarity between the two clusters
     */
    private double getDist(int a, int j, int[] size, List<Vec> vecs, List<Double> cache, List<Map<Integer, Double>> d_xk) {
        if (size[j] == 1 && size[a] == 1) {
            return this.dm.dist(a, j, vecs, cache);
        }
        Map<Integer, Double> rowA = d_xk.get(a);
        if (rowA != null) {
            Double stored = rowA.get(j);
            if (stored != null) {
                return stored;
            }
        }
        // Row a could not answer, so the value has to live in row j.
        return d_xk.get(j).get(a);
    }

    /**
     * Derives cluster assignments for the requested number of clusters from
     * the merge history of the last run, by delegating to
     * {@link PriorityHAC#assignClusterDesignations}.
     *
     * @param designations array passed on to receive the assignments
     * @param clusters the number of clusters to extract
     * @return the result of the delegate call, or {@code null} if no
     *         clustering has been run yet
     */
    public int[] getClusterDesignations(int[] designations, int clusters) {
        if (this.merges == null) {
            return null;
        }
        return PriorityHAC.assignClusterDesignations(designations, clusters, this.merges);
    }

    /**
     * Derives the requested number of clusters from the merge history of the
     * last run and groups the points of {@code data} accordingly.
     *
     * @param clusters the number of clusters to extract
     * @param data the data set the last clustering was performed on
     * @return the points grouped by cluster, or {@code null} if no clustering
     *         has been run or the stored history does not match the size of
     *         {@code data}
     */
    public List<List<DataPoint>> getClusterDesignations(int clusters, DataSet data) {
        // N points produce N - 1 merges, i.e. 2N - 2 stored indices.
        if (this.merges == null || (this.merges.length + 2) / 2 != data.getSampleSize()) {
            return null;
        }
        int[] assignments = new int[data.getSampleSize()];
        assignments = this.getClusterDesignations(assignments, clusters);
        return ClustererBase.createClusterListFromAssignmentArray(assignments, data);
    }

    /**
     * Clusters the data set into exactly {@code clusters} clusters.
     *
     * @param dataSet the data to cluster
     * @param clusters the number of clusters
     * @param parallel whether work may be spread over multiple threads
     * @param designations array to store the assignments in, or {@code null}
     *        to have one allocated
     * @return the array holding the cluster assignments
     */
    @Override
    public int[] cluster(DataSet dataSet, int clusters, boolean parallel, int[] designations) {
        return this.cluster(dataSet, clusters, clusters, parallel, designations);
    }

    /**
     * Runs the nearest-neighbor chain merge procedure over the whole data set,
     * stores the merge history, and assigns clusters for a cluster count
     * chosen from {@code [lowK, highK]}.
     *
     * @param dataSet the data to cluster
     * @param lowK the smallest number of clusters to consider
     * @param highK the largest number of clusters to consider
     * @param parallel whether work may be spread over multiple threads
     * @param designations array to store the assignments in, or {@code null}
     *        to have one allocated
     * @return the array holding the cluster assignments
     */
    @Override
    public int[] cluster(DataSet dataSet, int lowK, int highK, boolean parallel, int[] designations) {
        final int N = dataSet.getSampleSize();
        if (designations == null) {
            designations = new int[N];
        }
        if (N == 0) {
            // An empty data set has no merges; N * 2 - 2 would be a negative array size.
            this.merges = new int[0];
            return designations;
        }
        this.merges = new int[N * 2 - 2];
        // Merge log in the order the merges are discovered.
        IntList merge_removed = new IntList(N);
        IntList merge_kept = new IntList(N);
        double[] mergedDistance = new double[N - 1];
        int L_pos = 0;
        // Number of points currently represented by each cluster index.
        int[] size = new int[N];
        Arrays.fill(size, 1);
        // S holds the indices of the clusters that are still alive.
        IntList S = new IntList(N);
        ListUtils.addRange(S, 0, N, 1);
        // One row of stored dissimilarities per cluster index; null until that index receives a merged cluster.
        List<Map<Integer, Double>> dist_map = new ArrayList<>(N);
        for (int i = 0; i < N; ++i) {
            dist_map.add(null);
        }
        List<Vec> vecs = dataSet.getDataVectors();
        List<Double> cache = this.dm.getAccelerationCache(vecs, parallel);
        int[] chain = new int[N];
        int chainPos = 0;
        while (S.size() > 1) {
            int a;
            int b;
            if (chainPos <= 3) {
                // Chain too short to resume from: restart it at the first live cluster.
                a = S.getI(0);
                chainPos = 0;
                chain[chainPos++] = a;
                b = S.getI(1);
            } else {
                // Resume from the 4th/3rd-to-last entries and drop the last three.
                a = chain[chainPos - 4];
                b = chain[chainPos - 3];
                chainPos -= 3;
            }
            double dist_ab;
            do {
                // Search for the cluster nearest to a; b is the incumbent and is only
                // replaced by a strictly smaller distance.
                final AtomicInteger c = new AtomicInteger(b);
                final AtomicDouble minDist = new AtomicDouble(this.getDist(a, b, size, vecs, cache, dist_map));
                final int a_ = a;
                final int c_ = b;
                boolean doPara = parallel && S.size() > SystemInfo.LogicalCores * 2 && S.size() >= 100;
                ParallelUtils.run(doPara, S.size(), (start, end) -> {
                    double local_minDist = Double.POSITIVE_INFINITY;
                    int local_c = S.getI(start);
                    for (int i = start; i < end; ++i) {
                        int j = S.getI(i);
                        if (j == a_ || j == c_) {
                            continue;
                        }
                        double dist = this.getDist(a_, j, size, vecs, cache, dist_map);
                        if (dist < local_minDist) {
                            local_minDist = dist;
                            local_c = j;
                        }
                    }
                    // Fold this chunk's best candidate into the shared result.
                    synchronized (minDist) {
                        if (local_minDist < minDist.get()) {
                            minDist.set(local_minDist);
                            c.set(local_c);
                        }
                    }
                });
                dist_ab = minDist.get();
                // Step along the chain: the nearest neighbor becomes the new tip.
                b = a;
                a = c.get();
                chain[chainPos++] = a;
            } while (chainPos < 3 || a != chain[chainPos - 3]);
            // a and b are each other's nearest neighbor: merge them, keeping the smaller index.
            final int n = Math.min(a, b);
            final int removed = Math.max(a, b);
            merge_removed.add(removed);
            merge_kept.add(n);
            mergedDistance[L_pos++] = dist_ab;
            S.removeAll(Arrays.asList(a, b));
            // The chain tail may still name the dropped index; redirect it to the kept one.
            for (int i = Math.max(0, chainPos - 5); i < chainPos; ++i) {
                if (chain[i] == removed) {
                    chain[i] = n;
                }
            }
            final int size_a = size[a];
            final int size_b = size[b];
            final boolean singleThread = !parallel || S.size() <= SystemInfo.LogicalCores * 10;
            // Row of dissimilarities from the merged cluster n to every remaining cluster.
            final Map<Integer, Double> map_n;
            if (S.isEmpty()) {
                map_n = null;
            } else if (S.size() * 100 >= N || !singleThread) {
                map_n = new IntDoubleMapArray(N);
            } else {
                map_n = new IntDoubleMap(S.size());
                // Pre-populate every key so the puts below only overwrite existing entries.
                for (int x : S) {
                    map_n.put(x, -0.0);
                }
            }
            final int a_ = a;
            final int b_ = b;
            final double dist_ab_ = dist_ab;
            ParallelUtils.streamP(S.streamInts(), !singleThread).forEach(x -> {
                double d_ax = this.getDist(a_, x, size, vecs, cache, dist_map);
                double d_bx = this.getDist(b_, x, size, vecs, cache, dist_map);
                double d_xn = this.distMeasure.dissimilarity(size_a, size_b, size[x], dist_ab_, d_ax, d_bx);
                Map<Integer, Double> dist_map_x = dist_map.get(x);
                if (dist_map_x != null) {
                    // NOTE(review): only b_ is dropped here; when a_ is the removed index its
                    // entry stays behind in this row - confirm whether that is intended.
                    dist_map_x.remove(b_);
                    dist_map_x.put(n, d_xn);
                    // Swap a sparsely used row for an IntDoubleMap copy (presumably to save memory).
                    if (dist_map_x.size() * 50 < N && !(dist_map_x instanceof IntDoubleMap)) {
                        dist_map.set(x, new IntDoubleMap(dist_map_x));
                    }
                }
                map_n.put(x, d_xn);
            });
            dist_map.set(removed, null);
            dist_map.set(n, map_n);
            size[n] = size_a + size_b;
            S.add(n);
        }
        this.fixMergeOrderAndAssign(mergedDistance, merge_kept, merge_removed, lowK, N, highK, designations);
        return designations;
    }

    /**
     * Reorders the recorded merges by their merge distance, writes them into
     * {@link #merges}, picks a cluster count and assigns the designations.
     * <p>
     * The cluster count is the value in {@code [lowK, highK]} whose merge
     * distance lies the most standard deviations above the running mean of
     * the merge distances seen up to and including it; {@code lowK} is used
     * when no candidate qualifies.
     *
     * @param mergedDistance the distance of each merge, in discovery order;
     *        reordered in place
     * @param merge_kept the surviving index of each merge; reordered in place
     * @param merge_removed the dropped index of each merge; reordered in place
     * @param lowK the smallest number of clusters to consider
     * @param N the number of data points
     * @param highK the largest number of clusters to consider
     * @param designations array that receives the cluster assignments
     */
    private void fixMergeOrderAndAssign(double[] mergedDistance, IntList merge_kept, IntList merge_removed, int lowK, int N, int highK, int[] designations) {
        // Apply one common ordering, derived from the merge distances, to all three sequences.
        IndexTable it = new IndexTable(mergedDistance);
        it.apply(merge_kept);
        it.apply(merge_removed);
        it.apply(mergedDistance);
        for (int i = 0; i < it.length(); ++i) {
            this.merges[this.merges.length - i * 2 - 1] = merge_removed.getI(i);
            this.merges[this.merges.length - i * 2 - 2] = merge_kept.getI(i);
        }
        OnLineStatistics distChange = new OnLineStatistics();
        // Double.MIN_VALUE is the smallest positive double, so it is not a valid
        // starting point for a running maximum; negative infinity is.
        double maxStndDevs = Double.NEGATIVE_INFINITY;
        int clusterSize = lowK;
        for (int i = 0; i < mergedDistance.length; ++i) {
            distChange.add(mergedDistance[i]);
            // Number of clusters that exist right before merge i is performed.
            int curK = N - i;
            if (curK < lowK || curK > highK) {
                continue;
            }
            double stndDevs = (mergedDistance[i] - distChange.getMean()) / distChange.getStandardDeviation();
            // A NaN score (e.g. from a zero standard deviation) never compares greater and is skipped.
            if (stndDevs > maxStndDevs) {
                maxStndDevs = stndDevs;
                clusterSize = curK;
            }
        }
        PriorityHAC.assignClusterDesignations(designations, clusterSize, this.merges);
    }
}

