/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.meta;

import java.util.Arrays;
import java.util.List;
import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;

/**
 * Bootstrap-aggregating ("bagging") meta-regressor.
 *
 * <p>Each base regressor is trained on its own bootstrap sample of the training data.
 * A prediction is the column-wise mean of the base regressors' row predictions.
 */
public class Bagging
extends AbstractRegressor {
    private static final long serialVersionUID = -4111544932911087016L;
    /** The ensemble members. Always a mutable list owned by this instance. */
    private final List<Regressor> learningAlgorithms;
    /** Sample size per bootstrap; values {@code <= 0} mean the data set's default bootstrap. */
    private final int bootstrapSize;

    /**
     * Creates an ensemble of {@code count} fresh copies of a prototype regressor.
     *
     * <p>The default bootstrap sample size is used.
     *
     * @param learningAlgorithm prototype; each member is {@code learningAlgorithm.emptyCopy()}
     * @param count number of ensemble members
     */
    public Bagging(Regressor learningAlgorithm, int count) {
        this(learningAlgorithm, count, -1);
    }

    /**
     * Creates an ensemble of {@code count} fresh copies of a prototype regressor.
     *
     * @param learningAlgorithm prototype; each member is {@code learningAlgorithm.emptyCopy()}
     * @param count number of ensemble members
     * @param bootstrapSize sample size per bootstrap, or {@code <= 0} for the data set's default
     */
    public Bagging(Regressor learningAlgorithm, int count, int bootstrapSize) {
        this.bootstrapSize = bootstrapSize;
        this.learningAlgorithms = new FastArrayList<Regressor>();
        for (int i = 0; i < count; ++i) {
            this.learningAlgorithms.add(learningAlgorithm.emptyCopy());
        }
    }

    /**
     * Creates an ensemble from the given regressors using the default bootstrap sample size.
     *
     * @param learningAlgorithms the ensemble members (used as-is, not copied)
     */
    public Bagging(Regressor ... learningAlgorithms) {
        this(-1, learningAlgorithms);
    }

    /**
     * Creates an ensemble from the given regressors.
     *
     * @param bootstrapSize sample size per bootstrap, or {@code <= 0} for the data set's default
     * @param learningAlgorithms the ensemble members (used as-is, not copied)
     */
    public Bagging(int bootstrapSize, Regressor ... learningAlgorithms) {
        this.bootstrapSize = bootstrapSize;
        // Copy into a list owned by this instance. Arrays.asList would yield a fixed-size
        // view backed by the caller's array: add() would throw, and later writes to the
        // caller's array would silently change the ensemble.
        this.learningAlgorithms = new FastArrayList<Regressor>();
        this.learningAlgorithms.addAll(Arrays.asList(learningAlgorithms));
    }

    /**
     * Trains every ensemble member on an independent bootstrap sample of {@code dataSet}.
     *
     * @param dataSet the training data
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        for (Regressor learningAlgorithm : this.learningAlgorithms) {
            ListDataSet bootstrap = this.bootstrapSize > 0 ? dataSet.bootstrap(this.bootstrapSize) : dataSet.bootstrap();
            learningAlgorithm.trainAll(bootstrap);
        }
    }

    /** Resets every ensemble member. */
    @Override
    public void reset() {
        for (Regressor learningAlgorithm : this.learningAlgorithms) {
            learningAlgorithm.reset();
        }
    }

    /**
     * Predicts by averaging the members' predictions.
     *
     * <p>The member predictions are stacked vertically, and the mean is taken along
     * dimension 0. The trailing {@code true} flag is passed through to {@code Matrix.mean}.
     *
     * @param input the input sample
     * @return the averaged prediction
     * @throws IllegalStateException if the ensemble has no members
     */
    @Override
    public Matrix predictOne(Matrix input) {
        if (this.learningAlgorithms.isEmpty()) {
            throw new IllegalStateException("Bagging ensemble contains no regressors");
        }
        FastArrayList<Matrix> results = new FastArrayList<Matrix>();
        for (Regressor learningAlgorithm : this.learningAlgorithms) {
            results.add(learningAlgorithm.predictOne(input));
        }
        Matrix all = Matrix.Factory.vertCat(results);
        return all.mean(Calculation.Ret.NEW, 0, true);
    }

    /**
     * Returns an untrained ensemble with the same structure and bootstrap size.
     *
     * <p>Each member is replaced by its own {@code emptyCopy()}.
     *
     * @return a new, untrained {@code Bagging} instance
     */
    @Override
    public Regressor emptyCopy() {
        Regressor[] copies = new Regressor[this.learningAlgorithms.size()];
        for (int i = 0; i < copies.length; ++i) {
            copies[i] = this.learningAlgorithms.get(i).emptyCopy();
        }
        // Preserve the configured bootstrap size; the previous implementation reset it to -1.
        return new Bagging(this.bootstrapSize, copies);
    }
}

