/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollection;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.math.MathTricks;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;

/**
 * k-nearest-neighbour model usable both as a {@link Classifier} and as a
 * {@link Regressor}.
 *
 * <p>Training stores every numeric training vector, paired with its target
 * (class index or regression value), in a {@link VectorCollection}. Prediction
 * asks that collection for the {@code k} nearest stored vectors and combines
 * their targets, optionally weighting each neighbour by its distance to the
 * query.
 *
 * <p>The model only accepts data sets without categorical features. An instance
 * is trained for exactly one task at a time: calling {@link #classify} after a
 * regression fit (or {@link #regress} after a classification fit) raises
 * {@link UntrainedModelException}.
 */
public class NearestNeighbour
implements Classifier,
Regressor,
Parameterized {
    private static final long serialVersionUID = 4239569189624285932L;
    /** Number of neighbours requested from the vector collection. */
    private int k;
    /** Whether neighbours are weighted by distance instead of voting equally. */
    private boolean weighted;
    /** Metric handed to the vector collection when it is built. */
    private DistanceMetric distanceMetric;
    /** Description of the target class; set by the classification train method. */
    private CategoricalData predicting;
    /** Training vectors, each paired with its target value. */
    private VectorCollection<VecPaired<Vec, Double>> vecCollection;
    /** Task the model was last trained for; {@code null} until trained. */
    Mode mode = null;

    /**
     * Returns the number of neighbours consulted per prediction.
     *
     * @return the current value of k
     */
    public int getNeighbors() {
        return this.k;
    }

    /**
     * Sets the number of neighbours consulted per prediction.
     *
     * @param k the number of neighbours, must be at least 1
     * @throws ArithmeticException if {@code k} is less than 1
     */
    public void setNeighbors(int k) {
        NearestNeighbour.checkNeighborCount(k);
        this.k = k;
    }

    /**
     * Returns the argument unchanged. This overload does not read the model's
     * state; it is retained only so existing callers keep compiling. Use
     * {@link #getNeighbors()} to obtain the configured number of neighbours.
     *
     * @param k any value
     * @return {@code k}
     */
    public int getNeighbors(int k) {
        return k;
    }

    /**
     * Returns the distance metric used to compare vectors.
     *
     * @return the configured metric
     */
    public DistanceMetric getDistanceMetric() {
        return this.distanceMetric;
    }

    /**
     * Sets the distance metric used the next time the model is trained.
     *
     * @param distanceMetric the metric to use, must not be {@code null}
     * @throws NullPointerException if {@code distanceMetric} is {@code null}
     */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        NearestNeighbour.checkMetric(distanceMetric);
        this.distanceMetric = distanceMetric;
    }

    /**
     * Creates an unweighted model using the Euclidean distance.
     *
     * @param k the number of neighbours, must be at least 1
     */
    public NearestNeighbour(int k) {
        this(k, false);
    }

    /**
     * Creates an unweighted model using the Euclidean distance and the given
     * vector collection.
     *
     * @param k the number of neighbours, must be at least 1
     * @param vcf the collection that will store the training vectors
     */
    public NearestNeighbour(int k, VectorCollection<VecPaired<Vec, Double>> vcf) {
        this(k, false, new EuclideanDistance(), vcf);
    }

    /**
     * Creates a model using the Euclidean distance.
     *
     * @param k the number of neighbours, must be at least 1
     * @param weighted {@code true} to weight neighbours by distance
     */
    public NearestNeighbour(int k, boolean weighted) {
        this(k, weighted, new EuclideanDistance());
    }

    /**
     * Creates a model backed by a {@link DefaultVectorCollection}.
     *
     * @param k the number of neighbours, must be at least 1
     * @param weighted {@code true} to weight neighbours by distance
     * @param distanceMetric the metric to use, must not be {@code null}
     */
    public NearestNeighbour(int k, boolean weighted, DistanceMetric distanceMetric) {
        this(k, weighted, distanceMetric, new DefaultVectorCollection<VecPaired<Vec, Double>>());
    }

    /**
     * Creates a fully specified model. The arguments are validated with the
     * same rules as {@link #setNeighbors(int)} and
     * {@link #setDistanceMetric(DistanceMetric)} so that an instance can never
     * start out in a state the setters would reject.
     *
     * @param k the number of neighbours, must be at least 1
     * @param weighted {@code true} to weight neighbours by distance
     * @param distanceMetric the metric to use, must not be {@code null}
     * @param vcf the collection that will store the training vectors
     * @throws ArithmeticException if {@code k} is less than 1
     * @throws NullPointerException if {@code distanceMetric} is {@code null}
     */
    public NearestNeighbour(int k, boolean weighted, DistanceMetric distanceMetric, VectorCollection<VecPaired<Vec, Double>> vcf) {
        NearestNeighbour.checkNeighborCount(k);
        NearestNeighbour.checkMetric(distanceMetric);
        this.vecCollection = vcf;
        this.k = k;
        this.weighted = weighted;
        this.distanceMetric = distanceMetric;
    }

    /**
     * Predicts class probabilities for a data point from its k nearest
     * training vectors.
     *
     * <p>Unweighted: every neighbour adds one vote to its class and the votes
     * are normalized. Weighted: every neighbour adds its distance-derived
     * weight (see {@link #distanceWeights(List)}) to its class.
     *
     * @param data the point to classify; only its numeric values are used
     * @return the per-class scores
     * @throws UntrainedModelException if the model has not been trained for
     *         classification
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.vecCollection == null || this.mode != Mode.CLASSIFICATION) {
            throw new UntrainedModelException("Classifier has not been trained for classification");
        }
        Vec query = data.getNumericalValues();
        List<VecPaired<VecPaired<Vec, Double>, Double>> knns = this.vecCollection.search(query, this.k);
        // null means "every neighbour counts as 1.0"
        double[] votes = this.weighted ? NearestNeighbour.distanceWeights(knns) : null;
        CategoricalResults results = new CategoricalResults(this.predicting.getNumOfCategories());
        for (int i = 0; i < knns.size(); ++i) {
            // the stored pair value is the class index written by train()
            VecPaired<Vec, Double> pm = knns.get(i).getVector();
            int index = (int)Math.round(pm.getPair());
            double vote = votes == null ? 1.0 : votes[i];
            results.setProb(index, results.getProb(index) + vote);
        }
        if (votes == null) {
            // the weighted votes come out of softmax and are not renormalized
            results.normalize();
        }
        return results;
    }

    /**
     * Converts neighbour distances into voting weights: each distance d
     * becomes 1 / (offset + d), where offset is one tenth of the smallest
     * distance plus 1e-15 (so a zero distance never divides by zero), and the
     * resulting values are then passed through {@code MathTricks.softmax}.
     *
     * @param knns the neighbours returned by the vector collection, each
     *        paired with its distance to the query
     * @return one weight per neighbour, in the same order as {@code knns}
     */
    private static double[] distanceWeights(List<VecPaired<VecPaired<Vec, Double>, Double>> knns) {
        double[] dists = new double[knns.size()];
        for (int i = 0; i < dists.length; ++i) {
            dists[i] = knns.get(i).getPair();
        }
        double offset = MathTricks.min(dists) * 0.1 + 1.0E-15;
        for (int i = 0; i < dists.length; ++i) {
            dists[i] = 1.0 / (offset + dists[i]);
        }
        MathTricks.softmax(dists, false);
        return dists;
    }

    /**
     * Trains the model for classification. Each stored vector is paired with
     * the index of its class.
     *
     * @param dataSet the training data, must have no categorical features
     * @param parallel forwarded to the distance metric training and to the
     *        vector collection build
     * @throws FailedToFitException if the data set has categorical features
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        NearestNeighbour.requireNumericOnly(dataSet.getNumCategoricalVars());
        this.mode = Mode.CLASSIFICATION;
        this.predicting = dataSet.getPredicting();
        ArrayList<VecPaired<Vec, Double>> dataPoints = new ArrayList<VecPaired<Vec, Double>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getClassSize(); ++i) {
            for (DataPoint dp : dataSet.getSamples(i)) {
                dataPoints.add(new VecPaired<Vec, Double>(dp.getNumericalValues(), Double.valueOf(i)));
            }
        }
        this.buildIndex((DataSet)dataSet, dataPoints, parallel);
    }

    /**
     * Predicts a numeric value as the (optionally weighted) mean of the
     * targets of the k nearest training vectors.
     *
     * <p>When weighting is enabled each neighbour's weight is 1 / d², with the
     * distance d clamped to at least 1e-8 so exact matches do not divide by
     * zero. If the search returns no neighbours the result is NaN (0 / 0).
     *
     * @param data the query point; only its numeric values are used
     * @return the predicted value
     * @throws UntrainedModelException if the model has not been trained for
     *         regression
     */
    @Override
    public double regress(DataPoint data) {
        if (this.vecCollection == null || this.mode != Mode.REGRESSION) {
            throw new UntrainedModelException("Classifier has not been trained for regression");
        }
        Vec query = data.getNumericalValues();
        List<VecPaired<VecPaired<Vec, Double>, Double>> knns = this.vecCollection.search(query, this.k);
        double result = 0.0;
        double weightSum = 0.0;
        for (VecPaired<VecPaired<Vec, Double>, Double> neighbor : knns) {
            double value = neighbor.getVector().getPair();
            double weight = 1.0;
            if (this.weighted) {
                double distance = Math.max(1.0E-8, neighbor.getPair());
                weight = 1.0 / Math.pow(distance, 2.0);
            }
            weightSum += weight;
            result += value * weight;
        }
        return result / weightSum;
    }

    /**
     * Trains the model for regression. Each stored vector is paired with its
     * target value.
     *
     * @param dataSet the training data, must have no categorical features
     * @param parallel forwarded to the distance metric training and to the
     *        vector collection build
     * @throws FailedToFitException if the data set has categorical features
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        NearestNeighbour.requireNumericOnly(dataSet.getNumCategoricalVars());
        this.mode = Mode.REGRESSION;
        ArrayList<VecPaired<Vec, Double>> dataPoints = new ArrayList<VecPaired<Vec, Double>>(dataSet.getSampleSize());
        for (int i = 0; i < dataSet.getSampleSize(); ++i) {
            DataPointPair<Double> dpp = dataSet.getDataPointPair(i);
            dataPoints.add(new VecPaired<Vec, Double>(dpp.getVector(), dpp.getPair()));
        }
        this.buildIndex((DataSet)dataSet, dataPoints, parallel);
    }

    /**
     * Returns a copy of this model with its own distance metric, target
     * description and vector collection.
     *
     * <p>The vector collection is cloned exactly once, and a {@code null}
     * collection is carried over as {@code null} instead of failing.
     *
     * @return an independent copy of this model
     */
    @Override
    public NearestNeighbour clone() {
        VectorCollection<VecPaired<Vec, Double>> collectionCopy = null;
        if (this.vecCollection != null) {
            collectionCopy = this.vecCollection.clone();
        }
        NearestNeighbour clone = new NearestNeighbour(this.k, this.weighted, this.distanceMetric.clone(), collectionCopy);
        if (this.predicting != null) {
            clone.predicting = this.predicting.clone();
        }
        clone.mode = this.mode;
        return clone;
    }

    /**
     * Suggests a distribution to sample the number of neighbours from.
     *
     * @param d the data set; not consulted by the current implementation
     * @return a discrete uniform distribution constructed with bounds 1 and 25
     */
    public static Distribution guessNeighbors(DataSet d) {
        return new UniformDiscrete(1, 25);
    }

    /**
     * Reports whether per-sample weights are taken into account.
     *
     * @return always {@code false}
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /** Rejects a neighbour count below 1 (shared by constructor and setter). */
    private static void checkNeighborCount(int k) {
        if (k < 1) {
            throw new ArithmeticException("Must be a positive number of neighbors");
        }
    }

    /** Rejects a {@code null} distance metric (shared by constructor and setter). */
    private static void checkMetric(DistanceMetric distanceMetric) {
        if (distanceMetric == null) {
            throw new NullPointerException("given metric was null");
        }
    }

    /** Fails the fit when the data set carries any categorical feature. */
    private static void requireNumericOnly(int numCategoricalVars) {
        if (numCategoricalVars != 0) {
            throw new FailedToFitException("KNN requires vector data only");
        }
    }

    /** Trains the metric if it needs training, then builds the vector collection. */
    private void buildIndex(DataSet dataSet, ArrayList<VecPaired<Vec, Double>> dataPoints, boolean parallel) {
        TrainableDistanceMetric.trainIfNeeded(this.distanceMetric, dataSet, parallel);
        this.vecCollection.build(parallel, dataPoints, this.distanceMetric);
    }

    /** The task an instance has been trained for. */
    private static enum Mode {
        REGRESSION,
        CLASSIFICATION;

    }
}

