/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.classification.ClassifierTrainer;
import smile.classification.SoftClassifier;
import smile.math.Math;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.distance.Metric;
import smile.neighbor.CoverTree;
import smile.neighbor.KDTree;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.NearestNeighborSearch;
import smile.neighbor.Neighbor;

/**
 * K-nearest neighbor classifier.
 *
 * <p>A sample is classified by majority vote among its {@code k} nearest
 * training samples, as returned by a {@link KNNSearch}. Ties are resolved in
 * favor of the smallest class label. Class labels must be exactly the integers
 * {@code 0, 1, ..., c-1} with {@code c >= 2}.
 *
 * @param <T> the type of input samples
 */
public class KNN<T>
implements SoftClassifier<T>,
Serializable {
    private static final long serialVersionUID = 1L;
    /** Nearest-neighbor search over the training samples. */
    private final KNNSearch<T, T> knn;
    /** Training labels, looked up by the {@code index} of each returned neighbor. */
    private final int[] y;
    /** Number of neighbors that vote on each prediction. */
    private final int k;
    /** Number of classes. */
    private final int c;

    /**
     * Creates a k-NN classifier on top of an existing nearest-neighbor search.
     *
     * @param knn the search structure over the training samples; the neighbor
     *            indices it returns are used to look up labels in {@code y}
     * @param y the class labels of the training samples
     * @param k the number of neighbors that vote
     * @throws IllegalArgumentException if {@code k < 1}, if a label is negative,
     *         if the labels are not consecutive starting at 0, or if fewer than
     *         two classes are present
     */
    public KNN(KNNSearch<T, T> knn, int[] y, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        this.knn = knn;
        this.k = k;
        this.y = y;
        this.c = countClasses(y);
    }

    /**
     * Creates a 1-nearest-neighbor classifier.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param distance the distance used to find neighbors
     * @throws IllegalArgumentException on invalid input, see
     *         {@link #KNN(Object[], int[], Distance, int)}
     */
    public KNN(T[] x, int[] y, Distance<T> distance) {
        this(x, y, distance, 1);
    }

    /**
     * Creates a k-NN classifier. A cover tree is built when {@code distance}
     * is a {@link Metric}; otherwise a linear (brute-force) search is used.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param distance the distance used to find neighbors
     * @param k the number of neighbors that vote
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, if {@code k < 1}, if a label is negative, if the labels
     *         are not consecutive starting at 0, or if fewer than two classes
     *         are present
     */
    public KNN(T[] x, int[] y, Distance<T> distance, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        // Validate labels before paying for the search-structure construction.
        this.c = countClasses(y);
        this.y = y;
        this.k = k;
        if (distance instanceof Metric) {
            // Metric<T> extends Distance<T>, so this is a checked downcast.
            this.knn = new CoverTree<T>(x, (Metric<T>) distance);
        } else {
            this.knn = new LinearSearch<T>(x, distance);
        }
    }

    /**
     * Validates class labels and returns the number of classes.
     *
     * @param y the class labels of the training samples
     * @return the number of distinct labels
     * @throws IllegalArgumentException if a label is negative, a label in
     *         {@code 0..max} is missing, or fewer than two classes are present
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Sorted distinct labels form 0..c-1 exactly when labels[i] == i for all i;
            // the first mismatch is the smallest missing class (including class 0).
            if (labels[i] != i) {
                throw new IllegalArgumentException("Missing class: " + i);
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /**
     * Trains a 1-NN classifier on real-valued vectors with Euclidean distance.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @return the trained classifier
     * @throws IllegalArgumentException on invalid input, see
     *         {@link #learn(double[][], int[], int)}
     */
    public static KNN<double[]> learn(double[][] x, int[] y) {
        return KNN.learn(x, y, 1);
    }

    /**
     * Trains a k-NN classifier on real-valued vectors. A KD-tree is used when
     * the samples have fewer than 10 features, otherwise a cover tree with
     * Euclidean distance.
     *
     * @param x the training samples
     * @param y the class labels of the training samples
     * @param k the number of neighbors that vote
     * @return the trained classifier
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length, if {@code k < 1}, if {@code x} is empty, or if the labels
     *         are invalid
     */
    public static KNN<double[]> learn(double[][] x, int[] y, int k) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (k < 1) {
            throw new IllegalArgumentException("Illegal k = " + k);
        }
        if (x.length == 0) {
            // x[0] below would otherwise throw ArrayIndexOutOfBoundsException.
            throw new IllegalArgumentException("Empty training data");
        }
        KNNSearch<double[], double[]> search;
        if (x[0].length < 10) {
            // The samples serve both as the tree keys and as the stored data.
            search = new KDTree<double[]>(x, x);
        } else {
            search = new CoverTree<double[]>(x, new EuclideanDistance());
        }
        return new KNN<double[]>(search, y, k);
    }

    /**
     * Predicts the class label of a sample.
     *
     * @param x the sample
     * @return the predicted class label
     */
    @Override
    public int predict(T x) {
        return this.predict(x, null);
    }

    /**
     * Predicts the class label of a sample and, optionally, the fraction of
     * neighbors voting for each class.
     *
     * @param x the sample
     * @param posteriori if not {@code null}, filled with the fraction of the
     *                   returned neighbors whose label is each class
     * @return the label with the most votes; ties go to the smallest label
     * @throws IllegalArgumentException if {@code posteriori} is non-null and its
     *         length differs from the number of classes
     */
    @Override
    public int predict(T x, double[] posteriori) {
        if (posteriori != null && posteriori.length != this.c) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.c));
        }
        Neighbor<T, T>[] neighbors = this.knn.knn(x, this.k);
        // Fast path only when no posteriori is requested; otherwise it must be filled.
        if (this.k == 1 && posteriori == null) {
            return this.y[neighbors[0].index];
        }
        int[] count = new int[this.c];
        for (Neighbor<T, T> neighbor : neighbors) {
            count[this.y[neighbor.index]]++;
        }
        if (posteriori != null) {
            for (int i = 0; i < this.c; ++i) {
                posteriori[i] = (double) count[i] / (double) neighbors.length;
            }
        }
        // Strict '>' keeps the first (smallest) label among tied maxima.
        int best = 0;
        for (int i = 1; i < this.c; ++i) {
            if (count[i] > count[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Trainer that builds {@link KNN} classifiers with a fixed distance and
     * neighbor count.
     *
     * @param <T> the type of input samples
     */
    public static class Trainer<T>
    extends ClassifierTrainer<T> {
        private final int k;
        private final Distance<T> distance;

        /**
         * Creates a trainer.
         *
         * @param distance the distance used to find neighbors
         * @param k the number of neighbors that vote
         * @throws IllegalArgumentException if {@code k < 1}
         */
        public Trainer(Distance<T> distance, int k) {
            if (k < 1) {
                throw new IllegalArgumentException("Invalid k of k-NN: " + k);
            }
            this.distance = distance;
            this.k = k;
        }

        /**
         * Trains a k-NN classifier on the given samples.
         *
         * @param x the training samples
         * @param y the class labels of the training samples
         * @return the trained classifier
         */
        @Override
        public KNN<T> train(T[] x, int[] y) {
            return new KNN<T>(x, y, this.distance, this.k);
        }
    }
}

