/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.forestCover;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Objects;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.samples.forestCover.Config;

/**
 * Creates, trains and persists the multilayer perceptron used by the forest-cover sample.
 *
 * <p>{@link #createNeuralNetwork()} writes an untrained network to
 * {@code config.getTrainedNetworkFileName()}; {@link #train()} reads that same file back,
 * trains it on {@code config.getNormalizedBalancedFileName()}, and overwrites the file with
 * the trained network. This instance also acts as the learning-event listener that reports
 * progress to standard output.
 */
public class TrainNetwork
implements LearningEventListener {
    /** Learning rate applied to the back-propagation rule of a newly created network. */
    private static final double LEARNING_RATE = 0.01;
    /** Total network error at which training is configured to stop. */
    private static final double MAX_ERROR = 0.1;
    /** Upper bound on training iterations configured for a newly created network. */
    private static final int MAX_ITERATIONS = 1000;
    /** Number of decimal places used when printing network error values. */
    private static final int DECIMAL_PLACES = 5;

    private final Config config;

    /**
     * @param config source of layer sizes and file names; must not be {@code null}
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public TrainNetwork(Config config) {
        this.config = Objects.requireNonNull(config, "config");
    }

    /**
     * Builds a four-layer perceptron (input, two hidden, output) sized from the config,
     * configures its learning rule, and saves it to the trained-network file.
     */
    public void createNeuralNetwork() {
        System.out.println("Creating neural network... ");
        MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(
                this.config.getInputCount(),
                this.config.getFirstHiddenLayerCount(),
                this.config.getSecondHiddenLayerCount(),
                this.config.getOutputCount());
        // Only BackPropagation-level settings are touched, so cast no narrower than that
        // (consistent with train(); avoids a needless ClassCastException).
        BackPropagation learningRule = (BackPropagation) neuralNet.getLearningRule();
        learningRule.setLearningRate(LEARNING_RATE);
        learningRule.setMaxError(MAX_ERROR);
        learningRule.setMaxIterations(MAX_ITERATIONS);
        System.out.println("Saving neural network to file... ");
        neuralNet.save(this.config.getTrainedNetworkFileName());
        System.out.println("Neural network successfully saved!");
    }

    /**
     * Loads the network previously written by {@link #createNeuralNetwork()}, trains it on
     * the normalized balanced data set while reporting progress through this listener, and
     * saves the result back to the same file.
     */
    public void train() {
        System.out.println("Training neural network... ");
        MultiLayerPerceptron neuralNet =
                (MultiLayerPerceptron) NeuralNetwork.createFromFile(this.config.getTrainedNetworkFileName());
        DataSet dataSet = DataSet.load(this.config.getNormalizedBalancedFileName());
        ((BackPropagation) neuralNet.getLearningRule()).addListener(this);
        neuralNet.learn(dataSet);
        System.out.println("Saving trained neural network to file... ");
        neuralNet.save(this.config.getTrainedNetworkFileName());
        System.out.println("Neural network successfully saved!");
    }

    /**
     * Formats {@code number} with {@link #DECIMAL_PLACES} decimals using HALF_UP rounding.
     *
     * @param number value to format
     * @return the rounded decimal string, or {@code String.valueOf(number)} for NaN/Infinity
     */
    private String formatDecimalNumber(double number) {
        if (Double.isNaN(number) || Double.isInfinite(number)) {
            // BigDecimal cannot represent non-finite values; a diverged run must still log.
            return String.valueOf(number);
        }
        // valueOf rounds from the canonical decimal form (Double.toString) rather than the exact
        // binary expansion that new BigDecimal(double) uses, which can sit just below a half.
        return BigDecimal.valueOf(number).setScale(DECIMAL_PLACES, RoundingMode.HALF_UP).toString();
    }

    /**
     * Prints a summary when learning stops, otherwise the current iteration and its error.
     *
     * @param event learning event; its source is expected to be a {@link BackPropagation} rule
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation) event.getSource();
        if (event.getEventType() == LearningEvent.Type.LEARNING_STOPPED) {
            double error = bp.getTotalNetworkError();
            System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations, ");
            System.out.println("With total error: " + this.formatDecimalNumber(error));
        } else {
            System.out.println("Iteration: " + bp.getCurrentIteration()
                    + " | Network error: " + this.formatDecimalNumber(bp.getTotalNetworkError()));
        }
    }
}

