/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.autotrain;

import java.util.ArrayList;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.neuroph.contrib.autotrain.Range;
import org.neuroph.contrib.autotrain.TrainingResult;
import org.neuroph.contrib.autotrain.TrainingSettings;
import org.neuroph.contrib.autotrain.TrainingStatistics;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * Exhaustive grid-search trainer for {@link MultiLayerPerceptron} networks.
 *
 * <p>{@link #train(DataSet)} builds one {@link TrainingSettings} for every combination of
 * hidden-neuron count, learning rate and maximum error in the configured ranges (each range is
 * inclusive of both ends). It then trains a fresh network with backpropagation for each
 * combination and appends a {@link TrainingResult} to {@link #getResults()}. When
 * {@link #repeat(int)} is used, every combination is trained several times. In that case a single
 * statistics result (iterations and MSE) is recorded per combination instead of the individual
 * runs.
 *
 * <p>This class performs no synchronization; configure and train each instance from one thread.
 */
public class AutoTrainer {
    private static final Logger LOGGER = Logger.getLogger(AutoTrainer.class.getName());

    /**
     * Slack, measured in steps, added when counting the values of a floating-point range. It
     * ensures that rounding in {@code (max - min) / step} (for example 1.9999999999999998 instead
     * of 2) does not drop the upper bound.
     */
    private static final double STEP_EPSILON = 1e-9;

    /** Settings generated by the most recent {@link #train(DataSet)} call. */
    private final List<TrainingSettings> trainingSettingsList = new ArrayList<TrainingSettings>();
    /** Results accumulated across all {@link #train(DataSet)} calls. */
    private final List<TrainingResult> results = new ArrayList<TrainingResult>();
    private int maxHiddenNeurons;
    private int minHiddenNeurons;
    private int hiddenNeuronsStep = 1;
    private double minLearningRate;
    private double maxLearningRate;
    private double learningRateStep = 0.1;
    /** Stored by {@link #setMaxMomentum(double)} but not currently read anywhere in this class. */
    private double maxMomentum = 0.9;
    /** Percentage of the data set used for training when a train/test split is requested. */
    private int splitPercentage = 100;
    private boolean splitTrainTest = false;
    private double maxErrorMin;
    private double maxErrorMax;
    private double maxErrorStep = 0.01;
    private int maxIterations;
    private TransferFunctionType transferFunction = TransferFunctionType.SIGMOID;
    private boolean generateStatistics = false;
    /** Number of training runs per setting; changed only by {@link #repeat(int)}. */
    private int repeat = 1;

    /** Creates a trainer with default ranges; configure it with the fluent setters before training. */
    public AutoTrainer() {
    }

    /**
     * Creates a trainer that uses a single maximum-error value.
     *
     * @param maxError      value passed to {@link BackPropagation#setMaxError(double)} for every run
     * @param maxIterations value passed to {@link BackPropagation#setMaxIterations(int)} for every run
     */
    public AutoTrainer(double maxError, int maxIterations) {
        this();
        this.maxErrorMin = maxError;
        // Without this the max-error range would be [maxError, 0]. That range is empty for any
        // positive maxError, so no training settings would ever be generated.
        this.maxErrorMax = maxError;
        this.maxIterations = maxIterations;
    }

    /**
     * Returns the results collected by {@link #train(DataSet)}.
     *
     * @return the live, mutable result list; it accumulates across repeated {@code train} calls
     */
    public List<TrainingResult> getResults() {
        return this.results;
    }

    /**
     * Sets the inclusive hidden-neuron range and the step between tried values.
     *
     * @param minHiddenNeurons smallest hidden-layer size
     * @param maxHiddenNeurons largest hidden-layer size
     * @param step             increment between sizes; must be positive when training
     * @return this trainer
     */
    public AutoTrainer setHiddenNeurons(int minHiddenNeurons, int maxHiddenNeurons, int step) {
        this.minHiddenNeurons = minHiddenNeurons;
        this.maxHiddenNeurons = maxHiddenNeurons;
        this.hiddenNeuronsStep = step;
        return this;
    }

    /**
     * Sets the inclusive hidden-neuron range, keeping the current step.
     *
     * @return this trainer
     */
    public AutoTrainer setHiddenNeurons(int minHiddenNeurons, int maxHiddenNeurons) {
        this.minHiddenNeurons = minHiddenNeurons;
        this.maxHiddenNeurons = maxHiddenNeurons;
        return this;
    }

    /**
     * Sets the hidden-neuron range from {@code range} (bounds truncated to {@code int}) and the step.
     *
     * @return this trainer
     */
    public AutoTrainer setHiddenNeurons(Range range, int step) {
        this.minHiddenNeurons = (int) range.getMin();
        this.maxHiddenNeurons = (int) range.getMax();
        this.hiddenNeuronsStep = step;
        return this;
    }

    /**
     * Sets the hidden-neuron range from {@code range} (bounds truncated to {@code int}).
     *
     * @return this trainer
     */
    public AutoTrainer setHiddenNeurons(Range range) {
        this.minHiddenNeurons = (int) range.getMin();
        this.maxHiddenNeurons = (int) range.getMax();
        return this;
    }

    /**
     * Sets the inclusive maximum-error range and the step between tried values.
     *
     * @param step increment between values; must be positive when training
     * @return this trainer
     */
    public AutoTrainer setMaxError(double maxErrorMin, double maxErrorMax, double step) {
        this.maxErrorMin = maxErrorMin;
        this.maxErrorMax = maxErrorMax;
        this.maxErrorStep = step;
        return this;
    }

    /**
     * Sets the inclusive maximum-error range, keeping the current step.
     *
     * @return this trainer
     */
    public AutoTrainer setMaxError(double maxErrorMin, double maxErrorMax) {
        this.maxErrorMin = maxErrorMin;
        this.maxErrorMax = maxErrorMax;
        return this;
    }

    /**
     * Sets the maximum-error range from {@code range} and the step.
     *
     * @return this trainer
     */
    public AutoTrainer setMaxError(Range range, double step) {
        this.maxErrorMin = range.getMin();
        this.maxErrorMax = range.getMax();
        this.maxErrorStep = step;
        return this;
    }

    /**
     * Sets the maximum-error range from {@code range}, keeping the current step.
     *
     * @return this trainer
     */
    public AutoTrainer setMaxError(Range range) {
        this.maxErrorMin = range.getMin();
        this.maxErrorMax = range.getMax();
        return this;
    }

    /**
     * Sets the inclusive learning-rate range and the step between tried values.
     *
     * @param step increment between values; must be positive when training
     * @return this trainer
     */
    public AutoTrainer setLearningRate(double minLearningRate, double maxLearningRate, double step) {
        this.minLearningRate = minLearningRate;
        this.maxLearningRate = maxLearningRate;
        this.learningRateStep = step;
        return this;
    }

    /**
     * Sets the inclusive learning-rate range, keeping the current step.
     *
     * @return this trainer
     */
    public AutoTrainer setLearningRate(double minLearningRate, double maxLearningRate) {
        this.minLearningRate = minLearningRate;
        this.maxLearningRate = maxLearningRate;
        return this;
    }

    /**
     * Sets the learning-rate range from {@code range} and the step.
     *
     * @return this trainer
     */
    public AutoTrainer setLearningRate(Range range, double step) {
        this.minLearningRate = range.getMin();
        this.maxLearningRate = range.getMax();
        this.learningRateStep = step;
        return this;
    }

    /**
     * Sets the learning-rate range from {@code range}, keeping the current step.
     *
     * @return this trainer
     */
    public AutoTrainer setLearningRate(Range range) {
        this.minLearningRate = range.getMin();
        this.maxLearningRate = range.getMax();
        return this;
    }

    /** @return the iteration limit applied to every training run */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the iteration limit applied to every training run.
     *
     * @return this trainer
     */
    public AutoTrainer setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
        return this;
    }

    /** @return the transfer function used for the neurons of every trained network */
    public TransferFunctionType getTransferFunction() {
        return this.transferFunction;
    }

    /**
     * Sets the transfer function used when building each network.
     *
     * @param transferFunction the transfer function; must not be {@code null}
     * @return this trainer
     * @throws NullPointerException if {@code transferFunction} is {@code null}
     */
    public AutoTrainer setTransferFunction(TransferFunctionType transferFunction) {
        if (transferFunction == null) {
            throw new NullPointerException("transferFunction");
        }
        this.transferFunction = transferFunction;
        return this;
    }

    /**
     * Stores a maximum momentum value.
     *
     * <p>NOTE(review): the value is not currently used by {@link #train(DataSet)}.
     */
    public void setMaxMomentum(double maxMomentum) {
        this.maxMomentum = maxMomentum;
    }

    /**
     * Trains every setting {@code times} times and records aggregated statistics per setting
     * instead of the individual runs.
     *
     * @param times number of runs per setting; must be at least 1
     * @return this trainer
     * @throws IllegalArgumentException if {@code times < 1}
     */
    public AutoTrainer repeat(int times) {
        if (times < 1) {
            // With zero runs the statistics step would be asked to aggregate an empty list.
            throw new IllegalArgumentException("times must be at least 1, was " + times);
        }
        this.repeat = times;
        this.generateStatistics = true;
        return this;
    }

    protected boolean generatesStatistics() {
        return this.generateStatistics;
    }

    /**
     * Trains on only the first {@code trainingPrecent} percent of the data set.
     *
     * @param trainingPrecent percentage of rows used for training, in {@code [0, 100]}
     * @return this trainer
     * @throws IllegalArgumentException if the percentage is outside {@code [0, 100]}
     */
    public AutoTrainer setTrainTestSplit(int trainingPrecent) {
        if (trainingPrecent < 0 || trainingPrecent > 100) {
            throw new IllegalArgumentException("training percentage must be in [0, 100], was " + trainingPrecent);
        }
        this.splitPercentage = trainingPrecent;
        this.splitTrainTest = true;
        return this;
    }

    protected boolean isSplitForTesting() {
        return this.splitTrainTest;
    }

    /**
     * Rebuilds {@link #trainingSettingsList} with one entry per grid point.
     *
     * @throws IllegalArgumentException if any step is not positive
     */
    private void generateTrainingSettings() {
        // Start from scratch so that calling train() more than once does not duplicate settings.
        this.trainingSettingsList.clear();
        if (this.hiddenNeuronsStep < 1) {
            throw new IllegalArgumentException("hidden neurons step must be positive, was " + this.hiddenNeuronsStep);
        }
        int learningRateCount = stepCount(this.minLearningRate, this.maxLearningRate, this.learningRateStep, "learning rate");
        int maxErrorCount = stepCount(this.maxErrorMin, this.maxErrorMax, this.maxErrorStep, "max error");
        for (int hiddenNeurons = this.minHiddenNeurons; hiddenNeurons <= this.maxHiddenNeurons; hiddenNeurons += this.hiddenNeuronsStep) {
            for (int i = 0; i < learningRateCount; i++) {
                double learningRate = valueAt(this.minLearningRate, this.maxLearningRate, this.learningRateStep, i);
                for (int j = 0; j < maxErrorCount; j++) {
                    double maxError = valueAt(this.maxErrorMin, this.maxErrorMax, this.maxErrorStep, j);
                    TrainingSettings ts = new TrainingSettings()
                            .setHiddenNeurons(hiddenNeurons)
                            .setLearningRate(learningRate)
                            .setMaxError(maxError)
                            .setMaxIterations(this.getMaxIterations());
                    this.trainingSettingsList.add(ts);
                }
            }
        }
        LOGGER.log(Level.INFO, "Generated : {0} settings.", this.trainingSettingsList.size());
    }

    /**
     * Counts the values {@code min, min + step, min + 2*step, ...} that lie in {@code [min, max]}.
     * The count is computed from the index instead of by repeated addition, so accumulated
     * rounding cannot skip the upper bound.
     *
     * @return the number of grid values, or 0 when {@code min > max} (or either bound is NaN)
     * @throws IllegalArgumentException if {@code step} is not positive (including NaN)
     */
    private static int stepCount(double min, double max, double step, String name) {
        if (!(step > 0)) {
            throw new IllegalArgumentException(name + " step must be positive, was " + step);
        }
        if (!(min <= max)) {
            return 0;
        }
        return (int) Math.floor((max - min) / step + STEP_EPSILON) + 1;
    }

    /**
     * Returns the {@code index}-th grid value, clamped to {@code max} to absorb rounding overshoot.
     */
    private static double valueAt(double min, double max, double step, int index) {
        return Math.min(min + index * step, max);
    }

    /**
     * Trains one network per generated setting (and per repetition) and records the results.
     *
     * <p>When a train/test split is configured, only the training portion is used for learning.
     * The held-out portion is currently not evaluated, and each result carries an empty
     * placeholder {@link ConfusionMatrix}.
     *
     * @param dataSet the full data set; its input/output sizes define the network's outer layers
     * @throws IllegalArgumentException if a configured range step is not positive
     */
    public void train(DataSet dataSet) {
        this.generateTrainingSettings();
        DataSet trainingSet = dataSet;
        if (this.splitTrainTest) {
            trainingSet = dataSet.split(this.splitPercentage, 100 - this.splitPercentage).get(0);
        }
        List<TrainingResult> statResults = this.generateStatistics ? new ArrayList<TrainingResult>() : null;
        int trainingNo = 0;
        for (TrainingSettings trainingSetting : this.trainingSettingsList) {
            System.out.println("-----------------------------------------------------------------------------------");
            System.out.println("##TRAINING: " + ++trainingNo);
            trainingSetting.setTrainingSet(this.splitPercentage);
            trainingSetting.setTestSet(100 - this.splitPercentage);
            for (int subtrainNo = 1; subtrainNo <= this.repeat; ++subtrainNo) {
                System.out.println("#SubTraining: " + subtrainNo);
                // Honour the configured transfer function (defaults to SIGMOID).
                MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(this.transferFunction,
                        dataSet.getInputSize(), trainingSetting.getHiddenNeurons(), dataSet.getOutputSize());
                BackPropagation bp = (BackPropagation) neuralNet.getLearningRule();
                bp.setLearningRate(trainingSetting.getLearningRate());
                bp.setMaxError(trainingSetting.getMaxError());
                bp.setMaxIterations(trainingSetting.getMaxIterations());
                neuralNet.learn(trainingSet);
                // Placeholder: no evaluation fills this matrix yet.
                ConfusionMatrix cm = new ConfusionMatrix(new String[]{""});
                TrainingResult result = new TrainingResult(trainingSetting, bp.getTotalNetworkError(), bp.getCurrentIteration(), cm);
                System.out.println(subtrainNo + ") iterations: " + bp.getCurrentIteration());
                if (this.generateStatistics) {
                    statResults.add(result);
                } else {
                    this.results.add(result);
                }
            }
            if (this.generateStatistics) {
                this.results.add(this.calculateTrainingStatistics(trainingSetting, statResults));
                statResults.clear();
            }
        }
    }

    /**
     * Aggregates the repeated runs of one setting into a single result carrying iteration and
     * MSE statistics.
     *
     * @param ts   the setting all runs were trained with
     * @param runs the individual run results for that setting
     * @return a new result for {@code ts} holding the computed statistics
     */
    private TrainingResult calculateTrainingStatistics(TrainingSettings ts, List<TrainingResult> runs) {
        System.out.println("working on statistic...");
        TrainingResult result = new TrainingResult(ts);
        TrainingStatistics iterationsStat = TrainingStatistics.calculateIterations(runs);
        TrainingStatistics mseStat = TrainingStatistics.calculateMSE(runs);
        result.setMSE(mseStat);
        result.setIterationStat(iterationsStat);
        return result;
    }

    /**
     * Feeds every row of {@code testSet} through {@code neuralNet}; outputs are currently discarded.
     *
     * <p>NOTE(review): not called anywhere in this class yet.
     */
    private void testNeuralNetwork(MultiLayerPerceptron neuralNet, DataSet testSet) {
        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
        }
    }
}

