/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicIntegerArray;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.svm.DCDs;
import jsat.clustering.SeedSelectionMethods;
import jsat.clustering.kmeans.HamerlyKMeans;
import jsat.datatransform.DataTransform;
import jsat.distributions.Distribution;
import jsat.distributions.Uniform;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.BoundedSortedList;
import jsat.utils.DoubleList;
import jsat.utils.IntSet;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/**
 * A Radial Basis Function (RBF) network.
 *
 * <p>Training has three phases:
 * <ol>
 * <li>A {@link Phase1Learner} selects the centroids (hidden neurons).</li>
 * <li>A {@link Phase2Learner} estimates one Gaussian bandwidth per centroid.</li>
 * <li>The base classifier or regressor is trained on the hidden-layer activations produced by
 * {@link #transform(DataPoint)}.</li>
 * </ol>
 *
 * <p>Because the class also implements {@link DataTransform}, it can be used purely as a
 * feature transform once fitted.
 */
public class RBFNet
implements Classifier,
Regressor,
DataTransform,
Parameterized {
    private static final long serialVersionUID = 5418896646203518062L;
    /** Number of centroids (hidden neurons) requested from the phase-1 learner. */
    private int numCentroids;
    /** Strategy used to place the centroids. */
    private Phase1Learner p1l;
    /** Strategy used to estimate each centroid's bandwidth. */
    private Phase2Learner p2l;
    /** Scale parameter handed to the phase-2 learner. */
    private double alpha;
    /** Neighbor count handed to the phase-2 learner (used by NEAREST_OTHER_CENTROID_AVERAGE). */
    private int p;
    /** Distance metric used for centroid selection, bandwidths, and activations. */
    private DistanceMetric dm;
    /** Whether hidden-layer activations are divided by their sum in {@link #transform}. */
    private boolean normalize = true;
    /** Output-layer model for classification; may alias {@link #baseRegressor}. */
    private Classifier baseClassifier;
    /** Output-layer model for regression; may alias {@link #baseClassifier}. */
    private Regressor baseRegressor;
    /** Acceleration cache produced by {@link #dm} for {@link #centroids}; may be null. */
    private List<Double> centroidDistCache;
    /** Learned centroids; null until trained. */
    private List<Vec> centroids;
    /** Learned per-centroid bandwidths, parallel to {@link #centroids}; null until trained. */
    private double[] bandwidths;

    /** Creates an RBF network with 100 centroids and the defaults of {@link #RBFNet(int)}. */
    public RBFNet() {
        this(100);
    }

    /**
     * Creates an RBF network with the given number of centroids and these defaults:
     * <ul>
     * <li>k-means centroid placement;</li>
     * <li>nearest-other-centroid-average bandwidths;</li>
     * <li>alpha 3 and p 3;</li>
     * <li>Euclidean distance;</li>
     * <li>a {@link DCDs} output layer.</li>
     * </ul>
     *
     * @param numCentroids number of hidden neurons, must be positive
     */
    public RBFNet(int numCentroids) {
        // The explicit Classifier cast selects one of the two overloaded constructors
        // unambiguously. Both register the model as classifier and, if applicable, regressor.
        this(numCentroids, Phase1Learner.K_MEANS, Phase2Learner.NEAREST_OTHER_CENTROID_AVERAGE, 3.0, 3, (DistanceMetric) new EuclideanDistance(), (Classifier) new DCDs());
    }

    /**
     * Creates an RBF network whose output layer is the given classifier. If the classifier is
     * also a {@link Regressor}, it is used for regression as well.
     *
     * @param numCentroids number of hidden neurons, must be positive
     * @param cl phase-1 (centroid placement) strategy
     * @param bl phase-2 (bandwidth estimation) strategy
     * @param alpha non-negative, finite bandwidth scale
     * @param p positive neighbor count for the phase-2 learner
     * @param dm distance metric
     * @param baseClassifier output-layer classifier
     */
    public RBFNet(int numCentroids, Phase1Learner cl, Phase2Learner bl, double alpha, int p, DistanceMetric dm, Classifier baseClassifier) {
        this.setNumCentroids(numCentroids);
        this.setPhase1Learner(cl);
        this.setPhase2Learner(bl);
        this.setAlpha(alpha);
        this.setP(p);
        this.setDistanceMetric(dm);
        this.baseClassifier = baseClassifier;
        if (baseClassifier instanceof Regressor) {
            this.baseRegressor = (Regressor) baseClassifier;
        }
    }

    /**
     * Creates an RBF network whose output layer is the given regressor. If the regressor is
     * also a {@link Classifier}, it is used for classification as well.
     *
     * @param numCentroids number of hidden neurons, must be positive
     * @param cl phase-1 (centroid placement) strategy
     * @param bl phase-2 (bandwidth estimation) strategy
     * @param alpha non-negative, finite bandwidth scale
     * @param p positive neighbor count for the phase-2 learner
     * @param dm distance metric
     * @param baseRegressor output-layer regressor
     */
    public RBFNet(int numCentroids, Phase1Learner cl, Phase2Learner bl, double alpha, int p, DistanceMetric dm, Regressor baseRegressor) {
        this.setNumCentroids(numCentroids);
        this.setPhase1Learner(cl);
        this.setPhase2Learner(bl);
        this.setAlpha(alpha);
        this.setP(p);
        this.setDistanceMetric(dm);
        this.baseRegressor = baseRegressor;
        if (baseRegressor instanceof Classifier) {
            this.baseClassifier = (Classifier) baseRegressor;
        }
    }

    /**
     * Deep-copy constructor. Copies hyperparameters (including the normalize flag), clones the
     * distance metric and base model, and copies any learned centroids, cache, and bandwidths.
     *
     * @param toCopy the network to copy
     */
    public RBFNet(RBFNet toCopy) {
        this.setNumCentroids(toCopy.getNumCentroids());
        this.setPhase1Learner(toCopy.getPhase1Learner());
        this.setPhase2Learner(toCopy.getPhase2Learner());
        this.setAlpha(toCopy.getAlpha());
        this.setP(toCopy.getP());
        this.setDistanceMetric(toCopy.getDistanceMetric().clone());
        // Previously omitted, which made clones silently revert to normalize = true.
        this.normalize = toCopy.normalize;
        // Clone a single model and alias it into both roles when it supports both, so the
        // copy keeps the same sharing structure as the original.
        if (toCopy.baseRegressor != null) {
            this.baseRegressor = toCopy.baseRegressor.clone();
            if (this.baseRegressor instanceof Classifier) {
                this.baseClassifier = (Classifier) this.baseRegressor;
            }
        } else if (toCopy.baseClassifier != null) {
            this.baseClassifier = toCopy.baseClassifier.clone();
            if (this.baseClassifier instanceof Regressor) {
                this.baseRegressor = (Regressor) this.baseClassifier;
            }
        }
        if (toCopy.centroids != null) {
            this.centroids = new ArrayList<Vec>(toCopy.centroids.size());
            for (Vec v : toCopy.centroids) {
                this.centroids.add(v.clone());
            }
            if (toCopy.centroidDistCache != null) {
                this.centroidDistCache = new DoubleList(toCopy.centroidDistCache);
            }
        }
        if (toCopy.bandwidths != null) {
            this.bandwidths = Arrays.copyOf(toCopy.bandwidths, toCopy.bandwidths.length);
        }
    }

    /**
     * Maps a data point into the RBF hidden-layer space.
     *
     * <p>Feature {@code i} is {@code exp(-d^2 / (2 * sigma_i^2))}, where {@code d} is the
     * distance to centroid {@code i} and {@code sigma_i} is its bandwidth. The output is built
     * as follows:
     * <ul>
     * <li>Activations that are not above 1e-16 (including NaN) are left as zero.</li>
     * <li>If that leaves every feature zero, the single most-activated neuron is kept.</li>
     * <li>When {@link #isNormalize()} is true and the sum is non-zero, activations are divided
     * by their sum.</li>
     * <li>The vector is stored sparsely unless more than half the entries are non-zero.</li>
     * </ul>
     *
     * <p>Categorical values and weight are carried over unchanged. The network must have been
     * trained first.
     *
     * @param dp the point to transform
     * @return a new data point holding the hidden-layer activations
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        List<Double> qi = this.dm.getQueryInfo(x);
        // Size by the learned centroids, not numCentroids: the latter may have been changed via
        // setNumCentroids after training.
        int k = this.centroids.size();
        Vec sv = new SparseVector(k);
        double sum = 0.0;
        double maxActivation = Double.NEGATIVE_INFINITY;
        int highestNeuron = -1;
        for (int i = 0; i < k; ++i) {
            double dist = this.dm.dist(i, x, qi, this.centroids, this.centroidDistCache);
            double sig = this.bandwidths[i];
            double activation = Math.exp(-(dist * dist) / (sig * sig * 2.0));
            if (activation > maxActivation) {
                maxActivation = activation;
                highestNeuron = i;
            }
            // Skip negligible (and NaN) activations to keep the representation sparse.
            if (activation > 1.0E-16) {
                sv.set(i, activation);
                sum += activation;
            }
        }
        // Avoid emitting an all-zero vector by keeping the strongest neuron.
        // highestNeuron is still -1 only if every activation was NaN.
        if (sv.nnz() == 0 && highestNeuron >= 0) {
            sv.set(highestNeuron, maxActivation);
            sum = maxActivation;
        }
        if (this.normalize && sum != 0.0) {
            sv.mutableDivide(sum);
        }
        if (sv.nnz() > sv.length() / 2) {
            sv = new DenseVector(sv);
        }
        return new DataPoint(sv, dp.getCategoricalValues(), dp.getCategoricalData(), dp.getWeight());
    }

    /**
     * Sets the bandwidth scale passed to the phase-2 learner.
     *
     * @param alpha a non-negative, finite value
     * @throws IllegalArgumentException if alpha is negative, infinite, or NaN
     */
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || Double.isInfinite(alpha) || Double.isNaN(alpha)) {
            throw new IllegalArgumentException("Alpha must be a positive value, not " + alpha);
        }
        this.alpha = alpha;
    }

    /** @return the bandwidth scale passed to the phase-2 learner */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Suggests a search distribution for {@link #setAlpha(double)}.
     *
     * @param data the data set (unused)
     * @return a uniform distribution on [0.8, 3.5]
     */
    public static Distribution guessAlpha(DataSet data) {
        return new Uniform(0.8, 3.5);
    }

    /**
     * Sets the neighbor count passed to the phase-2 learner.
     *
     * @param p a positive count
     * @throws IllegalArgumentException if p is less than 1
     */
    public void setP(int p) {
        if (p < 1) {
            throw new IllegalArgumentException("neighbors parameter must be positive, not " + p);
        }
        this.p = p;
    }

    /** @return the neighbor count passed to the phase-2 learner */
    public int getP() {
        return this.p;
    }

    /**
     * Suggests a search distribution for {@link #setP(int)}.
     *
     * @param data the data set (unused)
     * @return a discrete uniform distribution on [2, 5]
     */
    public static Distribution guessP(DataSet data) {
        return new UniformDiscrete(2, 5);
    }

    /**
     * Sets the number of centroids to learn at the next training call.
     *
     * @param numCentroids a positive count
     * @throws IllegalArgumentException if numCentroids is less than 1
     */
    public void setNumCentroids(int numCentroids) {
        if (numCentroids < 1) {
            throw new IllegalArgumentException("Number of centroids must be positive, not " + numCentroids);
        }
        this.numCentroids = numCentroids;
    }

    /** @return the number of centroids requested at training time */
    public int getNumCentroids() {
        return this.numCentroids;
    }

    /**
     * Suggests a search distribution for {@link #setNumCentroids(int)}.
     *
     * @param data the data set (unused)
     * @return a discrete uniform distribution on [25, 1000]
     */
    public static Distribution guessNumCentroids(DataSet data) {
        return new UniformDiscrete(25, 1000);
    }

    /** @param dm the distance metric used throughout training and transformation */
    public void setDistanceMetric(DistanceMetric dm) {
        this.dm = dm;
    }

    /** @return the distance metric in use */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /** @param p1l the centroid placement strategy */
    public void setPhase1Learner(Phase1Learner p1l) {
        this.p1l = p1l;
    }

    /** @return the centroid placement strategy */
    public Phase1Learner getPhase1Learner() {
        return this.p1l;
    }

    /** @param p2l the bandwidth estimation strategy */
    public void setPhase2Learner(Phase2Learner p2l) {
        this.p2l = p2l;
    }

    /** @return the bandwidth estimation strategy */
    public Phase2Learner getPhase2Learner() {
        return this.p2l;
    }

    /** @param normalize whether {@link #transform} divides activations by their sum */
    public void setNormalize(boolean normalize) {
        this.normalize = normalize;
    }

    /** @return whether {@link #transform} divides activations by their sum */
    public boolean isNormalize() {
        return this.normalize;
    }

    /** Classifies by transforming the point and delegating to the base classifier. */
    @Override
    public CategoricalResults classify(DataPoint data) {
        return this.baseClassifier.classify(this.transform(data));
    }

    /**
     * Learns centroids and bandwidths from the data, then trains the base classifier on the
     * transformed data.
     *
     * @param dataSet training data
     * @param parallel whether to use multiple threads
     * @throws FailedToFitException if no base classifier was supplied
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (this.baseClassifier == null) {
            throw new FailedToFitException("RBFNet was not given a base classifier");
        }
        this.fitHiddenLayer(dataSet, parallel);
        ClassificationDataSet transformedData = dataSet.shallowClone();
        transformedData.applyTransform((DataTransform) this, parallel);
        this.baseClassifier.train(transformedData, parallel);
    }

    /**
     * Runs phases 1 and 2: places the centroids, builds the distance cache, and estimates the
     * bandwidths.
     */
    private void fitHiddenLayer(DataSet dataSet, boolean parallel) {
        this.centroids = this.p1l.getCentroids(dataSet, this.numCentroids, this.dm, parallel);
        this.centroidDistCache = this.dm.getAccelerationCache(this.centroids, parallel);
        // This executor is created for this call only, so it must be shut down here; otherwise
        // each training call leaks its worker threads.
        ExecutorService threadPool = ParallelUtils.getNewExecutor(parallel);
        try {
            this.bandwidths = this.p2l.estimateBandwidths(this.alpha, this.p, dataSet, this.centroids, this.centroidDistCache, this.dm, threadPool);
        } finally {
            threadPool.shutdownNow();
        }
    }

    /**
     * Reports weighted-data support of the base classifier, or else of the base regressor.
     * Returns false if neither is set.
     */
    @Override
    public boolean supportsWeightedData() {
        if (this.baseClassifier != null) {
            return this.baseClassifier.supportsWeightedData();
        }
        if (this.baseRegressor != null) {
            return this.baseRegressor.supportsWeightedData();
        }
        return false;
    }

    /** Predicts by transforming the point and delegating to the base regressor. */
    @Override
    public double regress(DataPoint data) {
        return this.baseRegressor.regress(this.transform(data));
    }

    /**
     * Fits the network as a transform, dispatching on the data set type.
     *
     * @param data a classification or regression data set
     * @throws FailedToFitException for any other data set type
     */
    @Override
    public void fit(DataSet data) {
        if (data instanceof ClassificationDataSet) {
            this.train((ClassificationDataSet) data);
        } else if (data instanceof RegressionDataSet) {
            this.train((RegressionDataSet) data);
        } else {
            throw new FailedToFitException("Data must be a classification or regression dataset, not " + data.getClass().getSimpleName());
        }
    }

    /**
     * Learns centroids and bandwidths from the data, then trains the base regressor on the
     * transformed data.
     *
     * @param dataSet training data
     * @param parallel whether to use multiple threads
     * @throws FailedToFitException if no base regressor was supplied
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        if (this.baseRegressor == null) {
            throw new FailedToFitException("RBFNet was not given a base regressor");
        }
        this.fitHiddenLayer(dataSet, parallel);
        RegressionDataSet transformedData = dataSet.shallowClone();
        transformedData.applyTransform((DataTransform) this, parallel);
        this.baseRegressor.train(transformedData, parallel);
    }

    /** @return a deep copy of this network (see {@link #RBFNet(RBFNet)}) */
    @Override
    public RBFNet clone() {
        return new RBFNet(this);
    }

    /** Strategies for estimating one Gaussian bandwidth per centroid (phase 2). */
    public static enum Phase2Learner {
        /**
         * Assigns every data point to its nearest centroid. Each bandwidth is the mean plus
         * {@code alpha} standard deviations of the distances from the points assigned to that
         * centroid.
         */
        CENTROID_DISTANCE{

            @Override
            protected double[] estimateBandwidths(double alpha, int p, DataSet data, List<Vec> centroids, List<Double> centroidDistCache, DistanceMetric dm, ExecutorService threadpool) {
                double[] bandwidths = new double[centroids.size()];
                OnLineStatistics[] averages = new OnLineStatistics[bandwidths.length];
                for (int i = 0; i < averages.length; ++i) {
                    averages[i] = new OnLineStatistics();
                }
                List<Vec> vectors = data.getDataVectors();
                ParallelUtils.run(true, data.getSampleSize(), (start, end) -> {
                    // Accumulate per chunk without locking, then merge once at the end.
                    OnLineStatistics[] localAverages = new OnLineStatistics[bandwidths.length];
                    for (int i = 0; i < localAverages.length; ++i) {
                        localAverages[i] = new OnLineStatistics();
                    }
                    for (int z = start; z < end; ++z) {
                        Vec x = vectors.get(z);
                        double minDist = Double.POSITIVE_INFINITY;
                        int minI = 0;
                        for (int i = 0; i < centroids.size(); ++i) {
                            double dist = dm.dist(i, x, centroids, centroidDistCache);
                            if (dist < minDist) {
                                minDist = dist;
                                minI = i;
                            }
                        }
                        localAverages[minI].add(minDist);
                    }
                    synchronized (averages) {
                        for (int i = 0; i < localAverages.length; ++i) {
                            if (localAverages[i].getSumOfWeights() != 0.0) {
                                averages[i] = OnLineStatistics.add(averages[i], localAverages[i]);
                            }
                        }
                    }
                }, threadpool);
                for (int i = 0; i < bandwidths.length; ++i) {
                    bandwidths[i] = averages[i].getMean() + averages[i].getStandardDeviation() * alpha;
                }
                return bandwidths;
            }
        }
        ,
        /**
         * Classification only. Labels each centroid with the majority class of the points
         * nearest to it; ties go to the lowest class index. Each bandwidth is {@code alpha}
         * times the distance to the nearest centroid with a different label. If every centroid
         * shares one label, the nearest other centroid is used instead.
         */
        CLOSEST_OPPOSITE_CENTROID{

            @Override
            protected double[] estimateBandwidths(double alpha, int p, DataSet data, List<Vec> centroids, List<Double> centroidDistCache, DistanceMetric dm, ExecutorService threadpool) {
                if (!(data instanceof ClassificationDataSet)) {
                    throw new FailedToFitException("CLOSEST_OPPOSITE_CENTROID only works for classification data sets");
                }
                ClassificationDataSet cds = (ClassificationDataSet) data;
                double[] bandwidths = new double[centroids.size()];
                // classLabels[c][y] counts points of class y whose nearest centroid is c.
                AtomicIntegerArray[] classLabels = new AtomicIntegerArray[centroids.size()];
                for (int i = 0; i < classLabels.length; ++i) {
                    classLabels[i] = new AtomicIntegerArray(cds.getClassSize());
                }
                ParallelUtils.run(true, data.getSampleSize(), (start, end) -> {
                    for (int id = start; id < end; ++id) {
                        Vec x = cds.getDataPoint(id).getNumericalValues();
                        double minDist = Double.POSITIVE_INFINITY;
                        int minI = 0;
                        for (int i = 0; i < centroids.size(); ++i) {
                            double dist = dm.dist(i, x, centroids, centroidDistCache);
                            if (dist < minDist) {
                                minDist = dist;
                                minI = i;
                            }
                        }
                        classLabels[minI].incrementAndGet(cds.getDataPointCategory(id));
                    }
                }, threadpool);
                int[] neuronClass = new int[centroids.size()];
                for (int i = 0; i < neuronClass.length; ++i) {
                    int maxVal = -1;
                    int maxClass = 0;
                    for (int j = 0; j < classLabels[i].length(); ++j) {
                        if (classLabels[i].get(j) > maxVal) {
                            maxClass = j;
                            maxVal = classLabels[i].get(j);
                        }
                    }
                    neuronClass[i] = maxClass;
                }
                ParallelUtils.run(true, centroids.size(), center -> {
                    double minDist = Double.POSITIVE_INFINITY;
                    for (int i = 0; i < centroids.size(); ++i) {
                        if (neuronClass[center] != neuronClass[i]) {
                            minDist = Math.min(minDist, dm.dist(i, center, centroids, centroidDistCache));
                        }
                    }
                    // No centroid with a different label: fall back to the nearest other one.
                    if (Double.isInfinite(minDist)) {
                        for (int i = 0; i < centroids.size(); ++i) {
                            if (center != i) {
                                minDist = Math.min(minDist, dm.dist(i, center, centroids, centroidDistCache));
                            }
                        }
                    }
                    bandwidths[center] = alpha * minDist;
                }, threadpool);
                return bandwidths;
            }
        }
        ,
        /**
         * Each bandwidth is the mean plus {@code alpha} standard deviations of the distances
         * from a centroid to the {@code p} closest other centroids, as retained by
         * {@link BoundedSortedList}.
         */
        NEAREST_OTHER_CENTROID_AVERAGE{

            @Override
            protected double[] estimateBandwidths(double alpha, int p, DataSet data, List<Vec> centroids, List<Double> centroidDistCache, DistanceMetric dm, ExecutorService threadpool) {
                double[] bandwidths = new double[centroids.size()];
                ParallelUtils.run(true, centroids.size(), center -> {
                    BoundedSortedList<Double> closestDistances = new BoundedSortedList<Double>(p);
                    for (int i = 0; i < centroids.size(); ++i) {
                        if (i != center) {
                            closestDistances.add(dm.dist(i, center, centroids, centroidDistCache));
                        }
                    }
                    OnLineStatistics stats = new OnLineStatistics();
                    for (Double dist : closestDistances) {
                        stats.add(dist);
                    }
                    bandwidths[center] = stats.getMean() + alpha * stats.getStandardDeviation();
                }, threadpool);
                return bandwidths;
            }
        };


        /**
         * Estimates one bandwidth per centroid.
         *
         * @param alpha bandwidth scale
         * @param p neighbor count (used by NEAREST_OTHER_CENTROID_AVERAGE)
         * @param data training data
         * @param centroids learned centroids
         * @param centroidDistCache distance acceleration cache for the centroids
         * @param dm distance metric
         * @param threadpool executor to run work on
         * @return bandwidths, parallel to {@code centroids}
         */
        protected abstract double[] estimateBandwidths(double alpha, int p, DataSet data, List<Vec> centroids, List<Double> centroidDistCache, DistanceMetric dm, ExecutorService threadpool);
    }

    /** Strategies for placing the centroids (phase 1). */
    public static enum Phase1Learner {
        /** Selects distinct data points uniformly at random and copies them as centroids. */
        RANDOM{

            @Override
            protected List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, boolean parallel) {
                int n = data.getSampleSize();
                // Without this check, the loop below never terminates when too few points exist.
                if (centroids > n) {
                    throw new FailedToFitException("Cannot select " + centroids + " distinct random centroids from a data set with only " + n + " points");
                }
                Random rand = RandomUtil.getRandom();
                IntSet points = new IntSet();
                while (points.size() < centroids) {
                    points.add(rand.nextInt(n));
                }
                List<Vec> toRet = new ArrayList<Vec>(centroids);
                for (int i : points) {
                    // Copy so the centroids do not alias (and are not mutated with) the data.
                    toRet.add(data.getDataPoint(i).getNumericalValues().clone());
                }
                return toRet;
            }
        }
        ,
        /** Runs {@link HamerlyKMeans} seeded with {@code SeedSelection.KPP} and uses its means. */
        K_MEANS{

            @Override
            protected List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, boolean parallel) {
                HamerlyKMeans kmeans = new HamerlyKMeans(dm, SeedSelectionMethods.SeedSelection.KPP);
                kmeans.cluster(data, centroids, parallel);
                return kmeans.getMeans();
            }
        };


        /**
         * Selects the centroids.
         *
         * @param data training data
         * @param centroids number of centroids to produce
         * @param dm distance metric
         * @param parallel whether to use multiple threads
         * @return the centroids
         */
        protected abstract List<Vec> getCentroids(DataSet data, int centroids, DistanceMetric dm, boolean parallel);
    }
}

