/*
 * Decompiled with CFR 0.152.
 */
package Catalano.MachineLearning.Classification;

import Catalano.Core.ArraysUtil;
import Catalano.MachineLearning.Classification.IClassifier;
import Catalano.MachineLearning.Dataset.DatasetClassification;
import Catalano.Math.Distances.IDivergence;
import Catalano.Math.Distances.SquaredEuclideanDistance;
import Catalano.Math.Matrix;
import Catalano.Statistics.Kernels.IMercerKernel;
import java.io.Serializable;

/**
 * k-nearest neighbors (k-NN) classifier.
 *
 * <p>Training only stores references to the training samples and their labels. Prediction
 * measures how far the query is from every stored sample, takes the {@code k} closest ones and
 * returns the label that receives the most votes among them.
 *
 * <p>Closeness is measured either with an {@link IDivergence} (smaller value = closer; squared
 * Euclidean distance by default) or, when a Mercer kernel is configured, with the kernel-induced
 * squared distance {@code K(x,x) - 2K(x,y) + K(y,y)}. The kernel value itself is a similarity,
 * so it cannot be minimized directly.
 *
 * <p>Labels are used directly as array indexes when counting votes, so they are expected to be
 * non-negative integers (typically {@code 0..C-1}).
 */
public class KNearestNeighbors implements IClassifier, Serializable {
    /** Number of neighbors that vote; kept at least 1. */
    private int k;
    /** Training samples, stored by reference (not copied). */
    private double[][] input;
    /** Training labels, parallel to {@code input}. */
    private int[] output;
    /** Distance used when no kernel is active. */
    private IDivergence divergence = new SquaredEuclideanDistance();
    /** Kernel used when {@code useKernel} is true. */
    private IMercerKernel kernel;
    /** Whether {@code kernel} (true) or {@code divergence} (false) measures closeness. */
    private boolean useKernel = false;

    /**
     * Returns the number of neighbors that vote on a prediction.
     *
     * @return the current {@code k}
     */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of neighbors that vote on a prediction.
     *
     * @param k number of neighbors; values below 1 are clamped to 1
     */
    public void setK(int k) {
        this.k = Math.max(1, k);
    }

    /**
     * Returns the configured Mercer kernel, which may be {@code null} if none was set.
     *
     * @return the kernel
     */
    public IMercerKernel getKernel() {
        return this.kernel;
    }

    /**
     * Sets a Mercer kernel and switches the classifier to kernel-induced distances.
     *
     * @param kernel the kernel to use
     */
    public void setMercerKernel(IMercerKernel kernel) {
        this.kernel = kernel;
        this.useKernel = true;
    }

    /**
     * Returns the configured divergence (distance) function.
     *
     * @return the divergence
     */
    public IDivergence getDistance() {
        return this.divergence;
    }

    /**
     * Sets the divergence (distance) function and switches the classifier away from any kernel.
     *
     * @param divergence the divergence to use
     */
    public void setDistance(IDivergence divergence) {
        this.divergence = divergence;
        this.useKernel = false;
    }

    /** Creates a 3-NN classifier using squared Euclidean distance. */
    public KNearestNeighbors() {
        this(3);
    }

    /**
     * Creates a k-NN classifier using squared Euclidean distance.
     *
     * @param k number of neighbors; values below 1 are clamped to 1
     */
    public KNearestNeighbors(int k) {
        // Previously hard-coded 3 here, silently ignoring the caller's k.
        this(k, new SquaredEuclideanDistance());
    }

    /**
     * Creates a k-NN classifier using the given divergence.
     *
     * @param k number of neighbors; values below 1 are clamped to 1
     * @param divergence distance function (smaller = closer)
     */
    public KNearestNeighbors(int k, IDivergence divergence) {
        // Same clamp as setK, inlined to avoid calling an overridable method from a constructor.
        this.k = Math.max(1, k);
        this.divergence = divergence;
    }

    /**
     * Creates a k-NN classifier that measures closeness through a Mercer kernel.
     *
     * @param k number of neighbors; values below 1 are clamped to 1
     * @param kernel the kernel whose induced distance is used
     */
    public KNearestNeighbors(int k, IMercerKernel kernel) {
        this.k = Math.max(1, k);
        this.kernel = kernel;
        this.useKernel = true;
    }

    /**
     * Stores the samples and labels of the given dataset.
     *
     * @param dataset training data
     */
    @Override
    public void Learn(DatasetClassification dataset) {
        this.Learn(dataset.getInput(), dataset.getOutput());
    }

    /**
     * Stores the training samples and labels. The arrays are kept by reference, not copied.
     *
     * @param input training samples, one row per sample
     * @param output label of each sample, parallel to {@code input}
     * @throws IllegalArgumentException if {@code input} and {@code output} differ in length
     */
    @Override
    public void Learn(double[][] input, int[] output) {
        if (input.length != output.length) {
            throw new IllegalArgumentException("input and output must have the same length: "
                    + input.length + " != " + output.length);
        }
        this.input = input;
        this.output = output;
    }

    /**
     * Predicts the label of a sample by majority vote among its nearest training samples.
     *
     * <p>If {@code k} exceeds the number of training samples, all samples vote. Ties between
     * labels are resolved by {@code Matrix.MaxIndex}.
     *
     * @param feature the sample to classify
     * @return the predicted label
     * @throws IllegalStateException if no (or empty) training data has been learned
     */
    @Override
    public int Predict(double[] feature) {
        if (this.input == null || this.input.length == 0) {
            throw new IllegalStateException("No training data: call Learn before Predict.");
        }
        double[] dist = this.computeDistances(feature);
        if (this.k == 1) {
            return this.output[Matrix.MinIndex(dist)];
        }
        int[] indexes = ArraysUtil.Argsort(dist, true);
        // Cap at the training-set size so a large k does not index past the sorted array.
        int neighbors = Math.min(this.k, dist.length);
        int[] votes = new int[Matrix.Max(this.output) + 1];
        for (int i = 0; i < neighbors; i++) {
            votes[this.output[indexes[i]]]++;
        }
        return Matrix.MaxIndex(votes);
    }

    /**
     * Computes a distance (smaller = closer) from {@code feature} to every training sample.
     *
     * @param feature the query sample
     * @return one distance per training sample, in training order
     */
    private double[] computeDistances(double[] feature) {
        int n = this.input.length;
        double[] dist = new double[n];
        if (this.useKernel) {
            // A kernel value is a similarity (higher = closer), so minimizing it directly would
            // select the farthest samples. Use the kernel-induced squared feature-space distance
            // ||phi(x) - phi(y)||^2 = K(x,x) - 2K(x,y) + K(y,y) instead.
            double kxx = this.kernel.Function(feature, feature);
            for (int i = 0; i < n; i++) {
                double[] sample = this.input[i];
                dist[i] = kxx - 2.0 * this.kernel.Function(feature, sample)
                        + this.kernel.Function(sample, sample);
            }
        } else {
            for (int i = 0; i < n; i++) {
                dist[i] = this.divergence.Compute(feature, this.input[i]);
            }
        }
        return dist;
    }

    /**
     * Returns a shallow copy via {@code Object.clone()}; training arrays, divergence and kernel
     * are shared with the original.
     *
     * @return the clone
     * @throws IllegalArgumentException if cloning is not supported
     */
    @Override
    public IClassifier clone() {
        try {
            return (IClassifier) super.clone();
        } catch (CloneNotSupportedException ex) {
            throw new IllegalArgumentException("Clone not supported: " + ex.getMessage(), ex);
        }
    }
}

