/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification;

import java.util.Collections;
import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;
import org.ujmp.core.util.Sortable;

/**
 * k-nearest-neighbour predictor.
 *
 * <p>Training only stores a reference to the given data set. {@link #predictOne(Matrix)}
 * scans every stored sample and keeps the {@code k} samples whose input matrix is closest
 * to the query, using {@code Matrix.euklideanDistanceTo}. It returns the mean of those
 * samples' target matrices.
 *
 * <p>{@link #trainAll(ListDataSet)} and {@link #reset()} replace the stored data set without
 * any synchronization.
 */
public class KNNClassifier
extends AbstractRegressor {
    private static final long serialVersionUID = 5971192321313837066L;

    /** Number of neighbours averaged per prediction; validated to be at least 1. */
    private final int k;

    /** Data set captured by {@link #trainAll(ListDataSet)}; {@code null} when untrained. */
    private ListDataSet dataSet = null;

    /**
     * Creates a classifier that averages the targets of the {@code k} closest training samples.
     *
     * @param k number of neighbours to use; must be at least 1
     * @throws IllegalArgumentException if {@code k} is less than 1
     */
    public KNNClassifier(int k) {
        if (k < 1) {
            // With k < 1 every prediction would index bestResults at k - 1 < 0.
            throw new IllegalArgumentException("k must be at least 1 but was " + k);
        }
        this.k = k;
    }

    /**
     * Returns a new, untrained classifier configured with the same {@code k}.
     *
     * @return an untrained copy of this classifier
     */
    @Override
    public Regressor emptyCopy() {
        return new KNNClassifier(this.k);
    }

    /**
     * Stores a reference to {@code dataSet} for use by later predictions.
     *
     * <p>No model is built and the data set is not copied, so later changes to it are seen by
     * {@link #predictOne(Matrix)}.
     *
     * @param dataSet the training samples
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        this.dataSet = dataSet;
    }

    /** Forgets the stored training data, returning the classifier to its untrained state. */
    @Override
    public void reset() {
        this.dataSet = null;
    }

    /**
     * Predicts the target for {@code input} by averaging the targets of its {@code k} nearest
     * training samples.
     *
     * <p>If the data set holds fewer than {@code k} samples, all of them are averaged.
     *
     * @param input the query, compared against each sample's input-label matrix
     * @return the mean of the selected neighbours' target-label matrices
     * @throws IllegalStateException if the classifier has not been trained, or was trained on
     *     an empty data set
     */
    @Override
    public Matrix predictOne(Matrix input) {
        if (this.dataSet == null) {
            throw new IllegalStateException("KNNClassifier has not been trained: call trainAll first");
        }
        // Kept sorted by ascending distance, so the current worst neighbour is always at the end.
        FastArrayList<Sortable<Double, Matrix>> bestResults = new FastArrayList<Sortable<Double, Matrix>>();
        // Index-based access lets the element type come from ListDataSet's own declaration.
        // The decompiled for-each typed elements as Object, which cannot call getAsMatrix.
        // NOTE(review): assumes ListDataSet is a java.util.List with cheap random access — confirm.
        for (int i = 0, n = this.dataSet.size(); i < n; i++) {
            Matrix sampleInput = this.dataSet.get(i).getAsMatrix(this.getInputLabel());
            // NOTE(review): the second argument is presumably "ignore NaN" — confirm against UJMP.
            double distance = input.euklideanDistanceTo(sampleInput, true);
            boolean full = bestResults.size() >= this.k;
            // Written as !(a < b) so a NaN distance never displaces an existing neighbour.
            if (full && !(distance < bestResults.get(this.k - 1).getComparable())) {
                continue;
            }
            if (full) {
                bestResults.remove(this.k - 1);
            }
            bestResults.add(new Sortable<Double, Matrix>(distance,
                    this.dataSet.get(i).getAsMatrix(this.getTargetLabel())));
            Collections.sort(bestResults);
        }
        if (bestResults.isEmpty()) {
            throw new IllegalStateException("KNNClassifier was trained on an empty data set");
        }
        FastArrayList<Matrix> results = new FastArrayList<Matrix>();
        for (Sortable<Double, Matrix> neighbour : bestResults) {
            results.add(neighbour.getObject().toColumnVector(Calculation.Ret.LINK));
        }
        // NOTE(review): each target is turned into a column vector and the vectors are stacked
        // vertically, then averaged over dimension 0. Confirm this yields a per-feature mean
        // rather than one scalar over all entries.
        Matrix resultMatrix = Matrix.Factory.vertCat(results);
        return resultMatrix.mean(Calculation.Ret.NEW, 0, true);
    }
}

