package org.neuroph.contrib.autotrain;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * Grid-search trainer for {@link MultiLayerPerceptron} networks.
 *
 * <p>Ranges are configured for the number of hidden neurons, the learning rate and the
 * maximum error. {@link #train(DataSet)} creates one {@link TrainingSettings} for every
 * combination of values in those ranges (hidden neurons outermost, max error innermost)
 * and trains a fresh network with {@link BackPropagation} for each combination,
 * {@link #repeat(int)} times. The outcomes are appended to {@link #getResults()}.
 *
 * <p>Instances are mutable and unsynchronized.
 */
public class AutoTrainer {
    private static final Logger LOGGER = Logger.getLogger(AutoTrainer.class.getName());

    /**
     * Slack (in units of one step) used when deciding whether the last value of a
     * floating-point range still lies within its upper bound. It absorbs rounding in
     * {@code (max - min) / step}, so that e.g. 0.1..0.3 with step 0.1 yields three values.
     */
    private static final double STEP_TOLERANCE = 1e-9;

    /** Settings generated for the current {@link #train(DataSet)} call. */
    private final List<TrainingSettings> trainingSettingsList;
    /** Results accumulated across all {@link #train(DataSet)} calls. */
    private final List<TrainingResult> results;
    private int maxHiddenNeurons;
    private int minHiddenNeurons;
    private int hiddenNeuronsStep;
    private double minLearningRate;
    private double maxLearningRate;
    private double learningRateStep;
    /** Stored by {@link #setMaxMomentum(double)}; not currently applied to the learning rule. */
    private double maxMomentum;
    /** Percentage of the data set used for training when splitting is enabled. */
    private int splitPercentage;
    private boolean splitTrainTest;
    private double maxErrorMin;
    private double maxErrorMax;
    private double maxErrorStep;
    private int maxIterations;
    private TransferFunctionType transferFunction;
    /** When true, repetitions of one setting are aggregated into a single statistics result. */
    private boolean generateStatistics;
    private int repeat;

    /**
     * Creates a trainer with default settings.
     *
     * <p>The defaults are: hidden-neuron step 1, learning-rate step 0.1, max-error step 0.01,
     * one repetition, no train/test split and a SIGMOID transfer function.
     */
    public AutoTrainer() {
        this.hiddenNeuronsStep = 1;
        this.learningRateStep = 0.1d;
        this.maxMomentum = 0.9d;
        this.splitPercentage = 100;
        this.splitTrainTest = false;
        this.maxErrorStep = 0.01d;
        this.generateStatistics = false;
        this.repeat = 1;
        this.trainingSettingsList = new ArrayList<>();
        this.results = new ArrayList<>();
        this.transferFunction = TransferFunctionType.SIGMOID;
    }

    /**
     * Creates a trainer that uses a single max-error value and an iteration limit.
     *
     * @param maxError      max error used for every generated setting; both the lower and
     *                      the upper bound of the max-error range are set to it
     * @param maxIterations iteration limit passed to the learning rule
     */
    public AutoTrainer(double maxError, int maxIterations) {
        this();
        this.maxErrorMin = maxError;
        // Without this the range would be [maxError, 0], which is empty for any positive error.
        this.maxErrorMax = maxError;
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the results collected so far.
     *
     * <p>The list is the live internal list. It keeps growing across {@link #train(DataSet)}
     * calls.
     *
     * @return the accumulated training results
     */
    public List<TrainingResult> getResults() {
        return this.results;
    }

    /**
     * Sets the inclusive hidden-neuron range and its step.
     *
     * @throws IllegalArgumentException if {@code step <= 0}
     */
    public AutoTrainer setHiddenNeurons(int min, int max, int step) {
        this.minHiddenNeurons = min;
        this.maxHiddenNeurons = max;
        this.hiddenNeuronsStep = requirePositiveStep(step, "hidden neurons");
        return this;
    }

    /** Sets the inclusive hidden-neuron range, keeping the current step. */
    public AutoTrainer setHiddenNeurons(int min, int max) {
        this.minHiddenNeurons = min;
        this.maxHiddenNeurons = max;
        return this;
    }

    /**
     * Sets the hidden-neuron range from {@code range} (bounds truncated to int) and its step.
     *
     * @throws IllegalArgumentException if {@code step <= 0}
     */
    public AutoTrainer setHiddenNeurons(Range range, int step) {
        this.minHiddenNeurons = (int) range.getMin();
        this.maxHiddenNeurons = (int) range.getMax();
        this.hiddenNeuronsStep = requirePositiveStep(step, "hidden neurons");
        return this;
    }

    /** Sets the hidden-neuron range from {@code range} (bounds truncated to int), keeping the step. */
    public AutoTrainer setHiddenNeurons(Range range) {
        this.minHiddenNeurons = (int) range.getMin();
        this.maxHiddenNeurons = (int) range.getMax();
        return this;
    }

    /**
     * Sets the inclusive max-error range and its step.
     *
     * @throws IllegalArgumentException if {@code step} is not strictly positive
     */
    public AutoTrainer setMaxError(double min, double max, double step) {
        this.maxErrorMin = min;
        this.maxErrorMax = max;
        this.maxErrorStep = requirePositiveStep(step, "max error");
        return this;
    }

    /** Sets the inclusive max-error range, keeping the current step. */
    public AutoTrainer setMaxError(double min, double max) {
        this.maxErrorMin = min;
        this.maxErrorMax = max;
        return this;
    }

    /**
     * Sets the max-error range from {@code range} and its step.
     *
     * @throws IllegalArgumentException if {@code step} is not strictly positive
     */
    public AutoTrainer setMaxError(Range range, double step) {
        this.maxErrorMin = range.getMin();
        this.maxErrorMax = range.getMax();
        this.maxErrorStep = requirePositiveStep(step, "max error");
        return this;
    }

    /** Sets the max-error range from {@code range}, keeping the current step. */
    public AutoTrainer setMaxError(Range range) {
        this.maxErrorMin = range.getMin();
        this.maxErrorMax = range.getMax();
        return this;
    }

    /**
     * Sets the inclusive learning-rate range and its step.
     *
     * @throws IllegalArgumentException if {@code step} is not strictly positive
     */
    public AutoTrainer setLearningRate(double min, double max, double step) {
        this.minLearningRate = min;
        this.maxLearningRate = max;
        this.learningRateStep = requirePositiveStep(step, "learning rate");
        return this;
    }

    /** Sets the inclusive learning-rate range, keeping the current step. */
    public AutoTrainer setLearningRate(double min, double max) {
        this.minLearningRate = min;
        this.maxLearningRate = max;
        return this;
    }

    /**
     * Sets the learning-rate range from {@code range} and its step.
     *
     * @throws IllegalArgumentException if {@code step} is not strictly positive
     */
    public AutoTrainer setLearningRate(Range range, double step) {
        this.minLearningRate = range.getMin();
        this.maxLearningRate = range.getMax();
        this.learningRateStep = requirePositiveStep(step, "learning rate");
        return this;
    }

    /** Sets the learning-rate range from {@code range}, keeping the current step. */
    public AutoTrainer setLearningRate(Range range) {
        this.minLearningRate = range.getMin();
        this.maxLearningRate = range.getMax();
        return this;
    }

    /** Returns the iteration limit copied into every generated setting. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** Sets the iteration limit copied into every generated setting. */
    public AutoTrainer setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /** Returns the transfer function used for the networks built by {@link #train(DataSet)}. */
    public TransferFunctionType getTransferFunction() {
        return this.transferFunction;
    }

    /**
     * Sets the transfer function used for the networks built by {@link #train(DataSet)}.
     *
     * @throws NullPointerException if {@code transferFunctionType} is null
     */
    public AutoTrainer setTransferFunction(TransferFunctionType transferFunctionType) {
        if (transferFunctionType == null) {
            throw new NullPointerException("transferFunctionType must not be null");
        }
        this.transferFunction = transferFunctionType;
        return this;
    }

    /** Stores a momentum value. NOTE: it is not currently applied during training. */
    public void setMaxMomentum(double maxMomentum) {
        this.maxMomentum = maxMomentum;
    }

    /**
     * Trains every setting {@code times} times and enables statistics aggregation.
     *
     * <p>With statistics enabled, only one aggregated result per setting is added to
     * {@link #getResults()}.
     *
     * @throws IllegalArgumentException if {@code times < 1}
     */
    public AutoTrainer repeat(int times) {
        if (times < 1) {
            throw new IllegalArgumentException("repeat count must be at least 1, got " + times);
        }
        this.repeat = times;
        this.generateStatistics = true;
        return this;
    }

    /** Returns whether repetitions are aggregated into statistics results. */
    public boolean generatesStatistics() {
        return this.generateStatistics;
    }

    /**
     * Enables a train/test split and uses {@code percentage} percent of the data set for training.
     *
     * @throws IllegalArgumentException if {@code percentage} is outside 0..100
     */
    public AutoTrainer setTrainTestSplit(int percentage) {
        if (percentage < 0 || percentage > 100) {
            throw new IllegalArgumentException("split percentage must be within 0..100, got " + percentage);
        }
        this.splitPercentage = percentage;
        this.splitTrainTest = true;
        return this;
    }

    /** Returns whether {@link #train(DataSet)} splits the data set before training. */
    public boolean isSplitForTesting() {
        return this.splitTrainTest;
    }

    /**
     * Rebuilds {@link #trainingSettingsList} from the configured ranges.
     *
     * <p>It produces one setting per (hidden neurons, learning rate, max error) combination.
     */
    private void generateTrainingSettings() {
        // Start fresh so repeated train() calls do not duplicate settings.
        this.trainingSettingsList.clear();

        // Loop-invariant counts; values are min + k*step rather than an accumulated sum,
        // so rounding error cannot drop the upper bound.
        int learningRateCount = stepCount(this.minLearningRate, this.maxLearningRate, this.learningRateStep);
        int maxErrorCount = stepCount(this.maxErrorMin, this.maxErrorMax, this.maxErrorStep);

        // long counter so hidden += step cannot overflow when maxHiddenNeurons is near Integer.MAX_VALUE.
        for (long hidden = this.minHiddenNeurons; hidden <= this.maxHiddenNeurons; hidden += this.hiddenNeuronsStep) {
            for (int lrIndex = 0; lrIndex < learningRateCount; lrIndex++) {
                double learningRate = this.minLearningRate + lrIndex * this.learningRateStep;
                for (int errIndex = 0; errIndex < maxErrorCount; errIndex++) {
                    double maxError = this.maxErrorMin + errIndex * this.maxErrorStep;
                    this.trainingSettingsList.add(new TrainingSettings()
                            .setHiddenNeurons((int) hidden)
                            .setLearningRate(learningRate)
                            .setMaxError(maxError)
                            .setMaxIterations(getMaxIterations()));
                }
            }
        }
        LOGGER.log(Level.INFO, "Generated : {0} settings.", Integer.valueOf(this.trainingSettingsList.size()));
    }

    /**
     * Counts the values {@code min, min + step, min + 2*step, ...} that lie within
     * {@code [min, max]}, tolerating floating-point rounding at the upper bound.
     *
     * @return 0 when the range is empty ({@code max < min}) or a bound is NaN
     */
    private static int stepCount(double min, double max, double step) {
        if (!(max >= min)) {
            return 0;
        }
        return (int) Math.floor((max - min) / step + STEP_TOLERANCE) + 1;
    }

    /** Validates an int range step; a non-positive step would never advance the loop. */
    private static int requirePositiveStep(int step, String what) {
        if (step <= 0) {
            throw new IllegalArgumentException(what + " step must be positive, got " + step);
        }
        return step;
    }

    /** Validates a double range step; rejects zero, negative and NaN values. */
    private static double requirePositiveStep(double step, String what) {
        if (!(step > 0)) {
            throw new IllegalArgumentException(what + " step must be positive, got " + step);
        }
        return step;
    }

    /**
     * Trains one network per generated setting ({@link #repeat(int)} times each) and appends
     * the outcomes to {@link #getResults()}.
     *
     * <p>Progress is printed to standard output.
     *
     * @param dataSet data whose input/output sizes define the network; when a split is
     *                configured, only the training part is used for learning
     * @throws NullPointerException if {@code dataSet} is null
     */
    public void train(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet must not be null");
        }
        generateTrainingSettings();

        DataSet trainingSet;
        if (this.splitTrainTest) {
            List<DataSet> parts = dataSet.split(this.splitPercentage, 100 - this.splitPercentage);
            // parts.get(1) is the held-out test set; evaluation on it is not implemented yet.
            trainingSet = parts.get(0);
        } else {
            trainingSet = dataSet;
        }

        // Per-setting buffer of repetition results; only filled when statistics are enabled.
        List<TrainingResult> repetitions = new ArrayList<>();
        int trainingNo = 0;
        for (TrainingSettings settings : this.trainingSettingsList) {
            System.out.println("-----------------------------------------------------------------------------------");
            trainingNo++;
            System.out.println("##TRAINING: " + trainingNo);
            settings.setTrainingSet(this.splitPercentage);
            settings.setTestSet(100 - this.splitPercentage);

            for (int run = 1; run <= this.repeat; run++) {
                System.out.println("#SubTraining: " + run);
                TrainingResult result = trainOnce(dataSet, trainingSet, settings, run);
                if (this.generateStatistics) {
                    repetitions.add(result);
                } else {
                    this.results.add(result);
                }
            }

            if (this.generateStatistics) {
                this.results.add(calculateTrainingStatistics(settings, repetitions));
                repetitions.clear();
            }
        }
    }

    /**
     * Builds a fresh network for {@code settings}, trains it on {@code trainingSet} and
     * returns its outcome.
     */
    private TrainingResult trainOnce(DataSet dataSet, DataSet trainingSet, TrainingSettings settings, int run) {
        // Layer sizes come from the full data set; the configured transfer function is applied.
        MultiLayerPerceptron network = new MultiLayerPerceptron(this.transferFunction,
                dataSet.getInputSize(), settings.getHiddenNeurons(), dataSet.getOutputSize());
        BackPropagation learningRule = network.getLearningRule();
        learningRule.setLearningRate(settings.getLearningRate());
        learningRule.setMaxError(settings.getMaxError());
        learningRule.setMaxIterations(settings.getMaxIterations());
        network.learn(trainingSet);

        // The ConfusionMatrix is a placeholder: no evaluation on a test set is performed yet.
        TrainingResult result = new TrainingResult(settings, learningRule.getTotalNetworkError(),
                learningRule.getCurrentIteration(), new ConfusionMatrix(new String[]{""}));
        System.out.println(run + ") iterations: " + learningRule.getCurrentIteration());
        return result;
    }

    /** Aggregates the repetition results of one setting into a single statistics result. */
    private TrainingResult calculateTrainingStatistics(TrainingSettings trainingSettings, List<TrainingResult> list) {
        System.out.println("working on statistic...");
        TrainingResult statistics = new TrainingResult(trainingSettings);
        TrainingStatistics iterationStats = TrainingStatistics.calculateIterations(list);
        statistics.setMSE(TrainingStatistics.calculateMSE(list));
        statistics.setIterationStat(iterationStats);
        return statistics;
    }

    /**
     * Feeds every row of {@code dataSet} through the network.
     *
     * <p>NOTE: not called anywhere in this class, and the computed outputs are discarded.
     */
    private void testNeuralNetwork(MultiLayerPerceptron multiLayerPerceptron, DataSet dataSet) {
        for (DataSetRow row : dataSet.getRows()) {
            multiLayerPerceptron.setInput(row.getInput());
            multiLayerPerceptron.calculate();
        }
    }
}
