package org.encog.ensemble;

import java.util.ArrayList;
import java.util.Iterator;
import org.encog.ensemble.EnsembleTypes;
import org.encog.ensemble.aggregator.WeightedAveraging;
import org.encog.ensemble.data.EnsembleDataSet;
import org.encog.ensemble.data.factories.EnsembleDataSetFactory;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;

/**
 * Base class for ensembles of machine-learning methods.
 *
 * <p>An ensemble keeps a list of {@link EnsembleML} members. Each member wraps an ML method
 * created by {@link #mlFactory}, is given its own training set by {@link #dataSetFactory}
 * and a trainer built by {@link #trainFactory}. Member outputs are combined by
 * {@link #aggregator}. Subclasses decide how members are initialised
 * ({@link #initMembers()}) and which problem type the ensemble addresses.
 */
public abstract class Ensemble {
    /** Default used when a caller does not supply the training-iteration or retry limit. */
    private static final int DEFAULT_MAX_ITERATIONS = 2000;

    /** Supplies member training sets (and the aggregator's data set). */
    protected EnsembleDataSetFactory dataSetFactory;
    /** Builds the trainer attached to each new member. */
    protected EnsembleTrainFactory trainFactory;
    /** Combines the members' outputs into the ensemble output. */
    protected EnsembleAggregator aggregator;
    /** Current members; not initialised here — presumably set up by subclasses. */
    protected ArrayList<EnsembleML> members;
    /** Creates and re-initialises the members' underlying ML methods. */
    protected EnsembleMLMethodFactory mlFactory;
    /**
     * Data set used by {@link #retrainAggregator()}. Populated by
     * {@link #initMembersBySplits(int)} when the aggregator needs training; may be {@code null}.
     */
    protected MLDataSet aggregatorDataSet;

    /**
     * Declared by {@link #addMember(EnsembleML)} so that subclasses can signal that the
     * operation is not supported; the base implementation never throws it.
     *
     * <p>NOTE(review): this is a non-static inner class, so every instance captures the
     * enclosing {@code Ensemble}; serializing it fails unless that ensemble is itself
     * serializable. Kept non-static to preserve source and binary compatibility.
     */
    public class NotPossibleInThisMethod extends Exception {
        private static final long serialVersionUID = 5118253806179408868L;

        /** Creates the exception without a detail message. */
        public NotPossibleInThisMethod() {
        }
    }

    /**
     * Thrown by the {@code trainMember} overloads when a member does not reach the
     * requested selection error within the allowed number of training attempts.
     *
     * <p>NOTE(review): non-static inner class (see {@link NotPossibleInThisMethod}).
     */
    public class TrainingAborted extends Exception {
        private static final long serialVersionUID = -5074472788684621859L;

        /**
         * @param str detail message
         */
        public TrainingAborted(String str) {
            super(str);
        }
    }

    /** Creates (or re-creates) the ensemble members; called whenever a factory or the data changes. */
    public abstract void initMembers();

    /**
     * Builds a new member wired to a fresh training set and trainer, without adding it to
     * {@link #members}.
     *
     * @return the newly created member
     */
    public EnsembleML generateNewMember() {
        return createMember();
    }

    /** Creates a {@link #generateNewMember()} result and appends it to {@link #members}. */
    public void addNewMember() {
        this.members.add(generateNewMember());
    }

    /**
     * Appends {@code splits} new members, each with its own data set drawn from
     * {@link #dataSetFactory}. If the aggregator needs training, one additional data set is
     * drawn and stored in {@link #aggregatorDataSet}.
     *
     * <p>Does nothing when there is no data-set factory, when {@code splits <= 0}, or when the
     * factory reports no source data.
     *
     * @param splits number of members to create
     */
    public void initMembersBySplits(int splits) {
        if (this.dataSetFactory == null || splits <= 0 || !this.dataSetFactory.hasSource()) {
            return;
        }
        for (int created = 0; created < splits; created++) {
            // Deliberately not generateNewMember(): keeps subclass overrides of that public
            // method from changing how split members are built.
            this.members.add(createMember());
        }
        if (this.aggregator.needsTraining()) {
            this.aggregatorDataSet = this.dataSetFactory.getNewDataSet();
        }
    }

    /**
     * Shared construction logic for new members: wraps an ML method sized by the data-set
     * factory, assigns a fresh training set and attaches a trainer for that method and set.
     */
    private GenericEnsembleML createMember() {
        GenericEnsembleML member = new GenericEnsembleML(
                this.mlFactory.createML(this.dataSetFactory.getInputCount(), this.dataSetFactory.getOutputCount()),
                this.mlFactory.getLabel());
        member.setTrainingSet(this.dataSetFactory.getNewDataSet());
        member.setTraining(this.trainFactory.getTraining(member.getMl(), member.getTrainingSet()));
        return member;
    }

    /**
     * Replaces the trainer factory and re-initialises the members.
     *
     * @param ensembleTrainFactory the new trainer factory
     */
    public void setTrainingMethod(EnsembleTrainFactory ensembleTrainFactory) {
        this.trainFactory = ensembleTrainFactory;
        initMembers();
    }

    /**
     * Passes {@code mLDataSet} to the data-set factory as its input data and re-initialises
     * the members.
     *
     * @param mLDataSet the source training data
     */
    public void setTrainingData(MLDataSet mLDataSet) {
        this.dataSetFactory.setInputData(mLDataSet);
        initMembers();
    }

    /**
     * Replaces the data-set factory and re-initialises the members.
     *
     * @param ensembleDataSetFactory the new data-set factory
     */
    public void setTrainingDataFactory(EnsembleDataSetFactory ensembleDataSetFactory) {
        this.dataSetFactory = ensembleDataSetFactory;
        initMembers();
    }

    /**
     * Trains the member at {@code index}, allowing {@link #DEFAULT_MAX_ITERATIONS} attempts.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(int index, double targetAccuracy, double selectionError, int maxIterations,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        trainMember(this.members.get(index), targetAccuracy, selectionError, maxIterations,
                DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
    }

    /**
     * Repeatedly re-initialises and trains {@code member} until its error on
     * {@code selectionSet} is at most {@code selectionError}.
     *
     * <p>At least one attempt is always made. A NaN error counts as a failed attempt.
     *
     * @param member         the member to train
     * @param targetAccuracy first argument forwarded to {@link EnsembleML#train} (presumably
     *                       the training target — confirm against EnsembleML)
     * @param selectionError maximum acceptable error on {@code selectionSet}
     * @param maxIterations  second argument forwarded to {@link EnsembleML#train}
     * @param maxLoops       maximum number of training attempts
     * @param selectionSet   data used to judge each attempt
     * @param verbose        if true, prints timing and error after every attempt
     * @throws TrainingAborted if {@code maxLoops} attempts all fail to reach {@code selectionError}
     */
    public void trainMember(EnsembleML member, double targetAccuracy, double selectionError, int maxIterations,
            int maxLoops, EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        int attempts = 0;
        while (true) {
            long start = System.nanoTime();
            this.mlFactory.reInit(member.getMl());
            member.train(targetAccuracy, maxIterations, verbose);
            long elapsed = System.nanoTime() - start;
            attempts++;
            // Evaluate once per attempt; the result drives both reporting and the stop test.
            double error = member.getError(selectionSet);
            if (verbose) {
                System.out.println("training took " + (elapsed / 1.0E9d));
                System.out.println("test MSE: " + error + " on " + selectionSet.size() + " data points");
            }
            // Check success before the attempt limit so a successful final attempt is kept.
            if (error <= selectionError) {
                return;
            }
            if (attempts >= maxLoops) {
                throw new TrainingAborted("Too many attempts at training ensemble member");
            }
        }
    }

    /**
     * Trains {@code member} using {@link #DEFAULT_MAX_ITERATIONS} for both the iteration and
     * attempt limits.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(EnsembleML member, double targetAccuracy, double selectionError,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        trainMember(member, targetAccuracy, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_ITERATIONS,
                selectionSet, verbose);
    }

    /**
     * Trains the member at {@code index} using {@link #DEFAULT_MAX_ITERATIONS} for both limits.
     *
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void trainMember(int index, double targetAccuracy, double selectionError,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        trainMember(index, targetAccuracy, selectionError, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
    }

    /**
     * Trains the aggregator on the concatenated member outputs for each pair in
     * {@link #aggregatorDataSet}, paired with that pair's ideal output.
     *
     * <p>Assumes every member emits {@code aggregatorDataSet.getIdealSize()} values.
     *
     * @throws IllegalStateException if {@link #aggregatorDataSet} has not been populated
     */
    public void retrainAggregator() {
        if (this.aggregatorDataSet == null) {
            throw new IllegalStateException(
                    "aggregatorDataSet is null; initialise members while the aggregator needs training first");
        }
        final int outputsPerMember = this.aggregatorDataSet.getIdealSize();
        final int combinedSize = this.members.size() * outputsPerMember;
        EnsembleDataSet aggregatorTrainingSet = new EnsembleDataSet(combinedSize, outputsPerMember);
        for (MLDataPair pair : this.aggregatorDataSet) {
            BasicMLData combined = new BasicMLData(combinedSize);
            int position = 0;
            for (EnsembleML member : this.members) {
                for (double value : member.compute(pair.getInput()).getData()) {
                    // Each slot of the freshly allocated vector is written exactly once.
                    combined.add(position++, value);
                }
            }
            aggregatorTrainingSet.add(combined, pair.getIdeal());
        }
        this.aggregator.setTrainingSet(aggregatorTrainingSet);
        this.aggregator.train();
    }

    /**
     * Trains every member in order, then retrains the aggregator if it needs training.
     *
     * @throws TrainingAborted if any member exhausts its attempts
     * @see #trainMember(EnsembleML, double, double, int, int, EnsembleDataSet, boolean)
     */
    public void train(double targetAccuracy, double selectionError, int maxIterations, int maxLoops,
            EnsembleDataSet selectionSet, boolean verbose) throws TrainingAborted {
        for (EnsembleML member : this.members) {
            trainMember(member, targetAccuracy, selectionError, maxIterations, maxLoops, selectionSet, verbose);
        }
        if (this.aggregator.needsTraining()) {
            retrainAggregator();
        }
    }

    /** Trains all members with {@link #DEFAULT_MAX_ITERATIONS} for both limits. */
    public void train(double targetAccuracy, double selectionError, EnsembleDataSet selectionSet,
            boolean verbose) throws TrainingAborted {
        train(targetAccuracy, selectionError, DEFAULT_MAX_ITERATIONS, DEFAULT_MAX_ITERATIONS, selectionSet, verbose);
    }

    /** Trains all members quietly with {@link #DEFAULT_MAX_ITERATIONS} for both limits. */
    public void train(double targetAccuracy, double selectionError, EnsembleDataSet selectionSet)
            throws TrainingAborted {
        train(targetAccuracy, selectionError, selectionSet, false);
    }

    /** Trains all members quietly with the given iteration limit and the default attempt limit. */
    public void train(double targetAccuracy, double selectionError, int maxIterations,
            EnsembleDataSet selectionSet) throws TrainingAborted {
        train(targetAccuracy, selectionError, maxIterations, DEFAULT_MAX_ITERATIONS, selectionSet, false);
    }

    /**
     * @param index member index
     * @return the training set assigned to that member
     */
    public MLDataSet getTrainingSet(int index) {
        return this.members.get(index).getTrainingSet();
    }

    /**
     * @param index member index
     * @return the member at {@code index}
     */
    public EnsembleML getMember(int index) {
        return this.members.get(index);
    }

    /**
     * Appends an externally built member.
     *
     * @throws NotPossibleInThisMethod never in this base implementation; subclasses may throw it
     */
    public void addMember(EnsembleML ensembleML) throws NotPossibleInThisMethod {
        this.members.add(ensembleML);
    }

    /**
     * Computes every member's output for {@code input} and lets the aggregator combine them.
     *
     * @param input the input vector
     * @return the aggregated output
     * @throws WeightedAveraging.WeightMismatchException as propagated from the aggregator
     */
    public MLData compute(MLData input) throws WeightedAveraging.WeightMismatchException {
        ArrayList<MLData> memberOutputs = new ArrayList<>(this.members.size());
        for (EnsembleML member : this.members) {
            memberOutputs.add(member.compute(input));
        }
        return this.aggregator.evaluate(memberOutputs);
    }

    /** @return the current aggregator */
    public EnsembleAggregator getAggregator() {
        return this.aggregator;
    }

    /**
     * Replaces the aggregator. Does not re-initialise members or populate
     * {@link #aggregatorDataSet}.
     */
    public void setAggregator(EnsembleAggregator ensembleAggregator) {
        this.aggregator = ensembleAggregator;
    }

    /** @return the kind of problem this ensemble addresses */
    public abstract EnsembleTypes.ProblemType getProblemType();
}
