/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.meta;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.dataset.DataSet;
import org.jdmp.core.dataset.ListDataSet;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/**
 * Meta-classifier that trains one independent copy of a prototype classifier per target
 * column and combines their outputs into a single prediction.
 *
 * <p>For each class index {@code i}, a fresh copy of the prototype (obtained via
 * {@link Classifier#emptyCopy()}) is trained on all inputs against target column {@code i}
 * alone. This is a one-vs-rest scheme, assuming the target matrix holds one indicator
 * column per class (TODO confirm against {@code getClassCount}).
 */
public class MultiClassClassifier extends AbstractClassifier {
    private static final long serialVersionUID = 466059743021340944L;

    /** Prototype whose {@code emptyCopy()} produces each per-class sub-classifier. */
    private final Classifier singleClassClassifier;

    /** Number of trained per-class sub-classifiers; 0 when untrained or after {@link #reset()}. */
    private int classCount = 0;

    /**
     * If {@code true}, each sub-classifier is trained on two target columns: the class
     * column and its complement {@code |t - 1|}.
     */
    private final boolean twoColumns;

    /** Trained sub-classifiers, index-aligned with the target columns. */
    private final List<Classifier> singleClassClassifiers = new ArrayList<Classifier>();

    /**
     * Creates a multi-class wrapper around a single-class prototype.
     *
     * @param singleClassClassifier prototype copied once per class during training; must not be null
     * @param twoColumns whether to train each sub-classifier on the class column plus its
     *     complement instead of the class column alone
     * @throws NullPointerException if {@code singleClassClassifier} is null
     */
    public MultiClassClassifier(Classifier singleClassClassifier, boolean twoColumns) {
        Objects.requireNonNull(singleClassClassifier, "singleClassClassifier");
        this.setLabel("MultiClassClassifier [" + singleClassClassifier.toString() + "]");
        this.singleClassClassifier = singleClassClassifier;
        this.twoColumns = twoColumns;
    }

    /**
     * Predicts a score for every trained class.
     *
     * <p>Score {@code i} is the {@code (0, 0)} entry of sub-classifier {@code i}'s prediction.
     * When {@code twoColumns} is set, that entry corresponds to the class column (column 0 of
     * the training target), not its complement.
     *
     * @param input a single input sample
     * @return the per-class scores as a matrix (the linked array is transposed; presumably
     *     one row with one column per class, so confirm against UJMP's {@code linkToArray}
     *     layout); empty when the model is untrained
     */
    @Override
    public Matrix predictOne(Matrix input) {
        final int count = this.classCount;
        double[] results = new double[count];
        for (int i = 0; i < count; ++i) {
            Classifier c = this.singleClassClassifiers.get(i);
            results[i] = c.predictOne(input).getAsDouble(0L, 0L);
        }
        return Matrix.Factory.linkToArray(results).transpose();
    }

    /** Discards all trained sub-classifiers and marks the model as untrained. */
    @Override
    public void reset() {
        this.singleClassClassifiers.clear();
        // Keep classCount consistent with the (now empty) list so predictOne cannot
        // index past the end of singleClassClassifiers.
        this.classCount = 0;
    }

    /**
     * Trains one sub-classifier per class on the given data set, replacing any previous model.
     *
     * <p>If training of any sub-classifier throws, the model stays untrained ({@code classCount}
     * remains 0), although sub-classifiers trained before the failure are kept in the list.
     *
     * @param dataSet training data; its target matrix is expected to have one column per class
     *     as counted by {@code getClassCount(dataSet)}
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        this.reset();
        final int count = this.getClassCount(dataSet);
        // The inputs and the full target matrix are the same for every class; fetch them once.
        final Matrix input = dataSet.getInputMatrix();
        final Matrix allTargets = dataSet.getTargetMatrix();
        for (int i = 0; i < count; ++i) {
            // NOTE(review): consider a logger instead of stdout.
            System.out.println("Training class " + i);
            Classifier c = this.singleClassClassifier.emptyCopy();
            Matrix target = allTargets.selectColumns(Calculation.Ret.LINK, i);
            if (this.twoColumns) {
                // |t - 1| equals 1 - t for 0/1 targets, i.e. the complement column.
                Matrix complement = target.minus(1.0).abs(Calculation.Ret.NEW);
                target = Matrix.Factory.horCat(target, complement);
            }
            ListDataSet ds = DataSet.Factory.linkToInputAndTarget(input, target);
            c.trainAll(ds);
            // Publish the sub-classifier only once it has been trained successfully.
            this.singleClassClassifiers.add(c);
        }
        // Set only after every class trained, so a failed run never advertises missing classes.
        this.classCount = count;
    }

    /**
     * Returns a new, untrained multi-class classifier sharing this instance's prototype and
     * {@code twoColumns} setting.
     *
     * @return an untrained copy of this configuration
     */
    @Override
    public Classifier emptyCopy() {
        return new MultiClassClassifier(this.singleClassClassifier, this.twoColumns);
    }
}

