package org.jdmp.core.algorithm.classification.meta;

import java.util.ArrayList;
import java.util.List;
import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.dataset.DataSet;
import org.jdmp.core.dataset.ListDataSet;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/meta/MultiClassClassifier.class */
/**
 * Meta classifier that splits a multi-class problem into one binary sub-problem per class.
 *
 * <p>During {@link #trainAll(ListDataSet)}, for each class index {@code i} a fresh copy of the
 * wrapped classifier (obtained via {@code emptyCopy()}) is trained on the full input matrix
 * against target column {@code i}. When {@code twoColumns} is enabled, the target becomes the two
 * columns {@code [x, |x - 1|]} built from column {@code i} (the second column is the complement
 * of the first when targets are 0/1 — NOTE(review): confirm targets are always 0/1).
 *
 * <p>{@link #predictOne(Matrix)} returns a single row holding, for each class, the value at
 * position (0, 0) of the corresponding per-class prediction.
 */
public class MultiClassClassifier extends AbstractClassifier {
    private static final long serialVersionUID = 466059743021340944L;

    /** Template classifier; only used to create empty copies, never trained itself. */
    private final Classifier singleClassClassifier;

    /** Whether each per-class classifier is trained on a two-column {@code [x, |x - 1|]} target. */
    private final boolean twoColumns;

    /** Number of classes determined by the most recent {@link #trainAll(ListDataSet)}. */
    private int classCount = 0;

    /** Per-class classifiers, in class-index order; empty until trained. */
    private final List<Classifier> singleClassClassifiers = new ArrayList<>();

    /**
     * Creates a multi-class wrapper around a single-class classifier.
     *
     * @param classifier template whose {@code emptyCopy()} is trained once per class
     * @param twoColumns if {@code true}, train each copy on {@code [x, |x - 1|]} instead of {@code x}
     */
    public MultiClassClassifier(Classifier classifier, boolean twoColumns) {
        setLabel("MultiClassClassifier [" + classifier.toString() + "]");
        this.singleClassClassifier = classifier;
        this.twoColumns = twoColumns;
    }

    /**
     * Predicts one sample by querying every per-class classifier.
     *
     * @param matrix the input sample
     * @return a single row with one entry per trained class: the (0, 0) value of that class's
     *     prediction; an empty row if the classifier has not been trained (or was reset)
     */
    @Override
    public Matrix predictOne(Matrix matrix) {
        // Iterate over the classifiers actually present rather than classCount, so a call after
        // reset() (or before training) cannot index past the end of the list.
        final int count = singleClassClassifiers.size();
        final double[] scores = new double[count];
        for (int i = 0; i < count; i++) {
            scores[i] = singleClassClassifiers.get(i).predictOne(matrix).getAsDouble(0, 0);
        }
        return Matrix.Factory.linkToArray(scores).transpose();
    }

    /** Discards all per-class classifiers and the recorded class count. */
    @Override
    public void reset() {
        singleClassClassifiers.clear();
        classCount = 0;
    }

    /**
     * Trains one copy of the template classifier per class.
     *
     * <p>Any previous training is discarded first. Progress is printed to standard output.
     *
     * @param listDataSet training data; its target matrix is expected to have one column per class
     */
    @Override
    public void trainAll(ListDataSet listDataSet) {
        reset();
        this.classCount = getClassCount(listDataSet);
        // Loop-invariant: fetch the input and target matrices once instead of once per class.
        final Matrix inputMatrix = listDataSet.getInputMatrix();
        final Matrix targetMatrix = listDataSet.getTargetMatrix();
        for (int i = 0; i < this.classCount; i++) {
            System.out.println("Training class " + i);
            final Classifier classifier = singleClassClassifier.emptyCopy();
            singleClassClassifiers.add(classifier);
            Matrix target = targetMatrix.selectColumns(Calculation.Ret.LINK, i);
            if (twoColumns) {
                // Second column is |x - 1|, i.e. the complement of x for 0/1 targets.
                target = Matrix.Factory.horCat(target, target.minus(1.0d).abs(Calculation.Ret.NEW));
            }
            classifier.trainAll(DataSet.Factory.linkToInputAndTarget(inputMatrix, target));
        }
    }

    /**
     * Returns an untrained wrapper with the same template classifier and column mode.
     *
     * @return a new, untrained {@code MultiClassClassifier}
     */
    @Override
    public Classifier emptyCopy() {
        return new MultiClassClassifier(singleClassClassifier, twoColumns);
    }
}
