package org.jdmp.core.algorithm.classification;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;
import org.ujmp.core.util.Sortable;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/KNNClassifier.class */
/**
 * k-nearest-neighbour predictor.
 *
 * <p>Training only remembers a reference to the supplied data set. All work happens in
 * {@link #predictOne(Matrix)}. That method finds the {@code k} training samples whose input
 * matrices have the smallest Euclidean distance to the query and averages their target matrices.
 */
public class KNNClassifier extends AbstractRegressor {
    private static final long serialVersionUID = 5971192321313837066L;

    /** Number of neighbours consulted per prediction; always at least 1. */
    private final int k;

    /** Training samples remembered by {@link #trainAll}; {@code null} until trained or after {@link #reset}. */
    private ListDataSet dataSet = null;

    /**
     * Creates an untrained classifier.
     *
     * @param k number of nearest neighbours to average; must be at least 1
     * @throws IllegalArgumentException if {@code k < 1}
     */
    public KNNClassifier(int k) {
        if (k < 1) {
            // Fail fast: with k < 1 predictOne could only ever fail with an index error.
            throw new IllegalArgumentException("k must be at least 1, but was " + k);
        }
        this.k = k;
    }

    /** Returns a new, untrained classifier configured with the same {@code k}. */
    @Override
    public Regressor emptyCopy() {
        return new KNNClassifier(this.k);
    }

    /**
     * "Trains" by remembering the given data set. No copy is made, so the samples are read
     * directly from {@code listDataSet} at prediction time.
     *
     * @param listDataSet the training samples
     */
    @Override
    public void trainAll(ListDataSet listDataSet) {
        this.dataSet = listDataSet;
    }

    /** Forgets the remembered training data; {@link #predictOne} fails until retrained. */
    @Override
    public void reset() {
        this.dataSet = null;
    }

    /**
     * Predicts a target for one input by averaging the targets of its {@code k} nearest
     * training samples.
     *
     * <p>Distances are computed with {@code euklideanDistanceTo(..., true)} against each sample's
     * input-label matrix. If fewer than {@code k} samples exist, all of them are used.
     *
     * @param matrix the query input
     * @return the mean of the nearest neighbours' target matrices
     * @throws IllegalStateException if called before {@link #trainAll} or after {@link #reset}
     */
    @Override
    public Matrix predictOne(Matrix matrix) {
        final ListDataSet samples = this.dataSet;
        if (samples == null) {
            throw new IllegalStateException("KNNClassifier has not been trained; call trainAll() first");
        }

        // Invariant: sorted by ascending distance and never longer than k.
        final List<Neighbor> nearest = new ArrayList<Neighbor>(this.k);
        for (Sample sample : samples) {
            final double distance = matrix.euklideanDistanceTo(sample.getAsMatrix(getInputLabel()), true);
            if (nearest.size() < this.k) {
                insertSorted(nearest, new Neighbor(distance, sample.getAsMatrix(getTargetLabel())));
            } else if (distance < nearest.get(nearest.size() - 1).distance) {
                // Strictly closer than the current k-th neighbour: evict it, then insert.
                nearest.remove(nearest.size() - 1);
                insertSorted(nearest, new Neighbor(distance, sample.getAsMatrix(getTargetLabel())));
            }
        }

        final List<Matrix> targets = new ArrayList<Matrix>(nearest.size());
        for (Neighbor neighbor : nearest) {
            targets.add(neighbor.target.toColumnVector(Calculation.Ret.LINK));
        }
        // NOTE(review): each target is turned into a column vector and the vectors are stacked
        // vertically before taking the mean along dimension 0. Confirm this yields the intended
        // per-component average for targets with more than one element.
        return Matrix.Factory.vertCat(targets).mean(Calculation.Ret.NEW, 0, true);
    }

    /**
     * Inserts {@code candidate} into the distance-sorted list after every entry whose distance is
     * not greater than its own. This is the same position a stable sort would assign to a newly
     * appended element, so ties keep first-seen-first order.
     */
    private static void insertSorted(List<Neighbor> sorted, Neighbor candidate) {
        int index = sorted.size();
        while (index > 0 && Double.compare(sorted.get(index - 1).distance, candidate.distance) > 0) {
            index--;
        }
        sorted.add(index, candidate);
    }

    /** A training sample's target matrix paired with its distance to the current query. */
    private static final class Neighbor {
        final double distance;
        final Matrix target;

        Neighbor(double distance, Matrix target) {
            this.distance = distance;
            this.target = target;
        }
    }
}
