package org.neuroph.samples.forestCover;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Objects;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;

/* loaded from: input_file:org/neuroph/samples/forestCover/TrainNetwork.class */
/**
 * Creates, trains and persists the multi-layer perceptron used by the forest cover sample.
 *
 * <p>Instances also act as a {@link LearningEventListener}: {@link #train()} registers the
 * instance on the network's learning rule so progress is printed to standard output.
 */
public class TrainNetwork implements LearningEventListener {
    /** Learning rate applied to the momentum back-propagation rule of a newly created network. */
    private static final double LEARNING_RATE = 0.01d;
    /** Target total network error passed to the learning rule's {@code setMaxError}. */
    private static final double MAX_ERROR = 0.1d;
    /** Iteration limit passed to the learning rule's {@code setMaxIterations}. */
    private static final int MAX_ITERATIONS = 1000;
    /** Number of fractional digits used when printing the final network error. */
    private static final int ERROR_SCALE = 5;

    private final Config config;

    /**
     * @param config supplies layer sizes and the file names used for loading and saving
     * @throws NullPointerException if {@code config} is null
     */
    public TrainNetwork(Config config) {
        // Fail fast here instead of with an NPE deep inside createNeuralNetwork()/train().
        this.config = Objects.requireNonNull(config, "config");
    }

    /**
     * Builds a four-layer perceptron (input, two hidden, output) sized from the config.
     *
     * <p>Configures its momentum back-propagation learning rule and saves the untrained network
     * to {@code config.getTrainedNetworkFileName()}.
     */
    public void createNeuralNetwork() {
        System.out.println("Creating neural network... ");
        MultiLayerPerceptron network = new MultiLayerPerceptron(
                this.config.getInputCount(),
                this.config.getFirstHiddenLayerCount(),
                this.config.getSecondHiddenLayerCount(),
                this.config.getOutputCount());
        // Assumes MultiLayerPerceptron installs a MomentumBackpropagation rule by default;
        // the cast fails with ClassCastException otherwise.
        MomentumBackpropagation learningRule = (MomentumBackpropagation) network.getLearningRule();
        learningRule.setLearningRate(LEARNING_RATE);
        learningRule.setMaxError(MAX_ERROR);
        learningRule.setMaxIterations(MAX_ITERATIONS);
        // NOTE(review): train() relies on these settings being saved together with the network
        // (it reloads the file and calls learn() without reconfiguring) - confirm serialization.
        System.out.println("Saving neural network to file... ");
        network.save(this.config.getTrainedNetworkFileName());
        System.out.println("Neural network successfully saved!");
    }

    /**
     * Loads the network saved by {@link #createNeuralNetwork()} and trains it on the
     * normalized, balanced data set.
     *
     * <p>The trained network then overwrites the same file.
     */
    public void train() {
        System.out.println("Training neural network... ");
        MultiLayerPerceptron network =
                (MultiLayerPerceptron) NeuralNetwork.createFromFile(this.config.getTrainedNetworkFileName());
        DataSet trainingSet = DataSet.load(this.config.getNormalizedBalancedFileName());
        // Receive progress callbacks in handleLearningEvent while learn() runs.
        network.getLearningRule().addListener(this);
        network.learn(trainingSet);
        System.out.println("Saving trained neural network to file... ");
        network.save(this.config.getTrainedNetworkFileName());
        System.out.println("Neural network successfully saved!");
    }

    /**
     * Formats {@code d} with {@link #ERROR_SCALE} fractional digits using HALF_UP rounding.
     *
     * <p>Uses {@link BigDecimal#valueOf(double)} rather than {@code new BigDecimal(double)} so
     * rounding is applied to the value's canonical decimal form. The exact binary expansion can
     * turn an apparent tie such as 0.123455 into 0.12345499... and round it the wrong way.
     * NaN and infinite values, which a diverging run can produce, are returned as-is; BigDecimal
     * would reject them with NumberFormatException.
     */
    private String formatDecimalNumber(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d)) {
            return String.valueOf(d);
        }
        return BigDecimal.valueOf(d).setScale(ERROR_SCALE, RoundingMode.HALF_UP).toPlainString();
    }

    /**
     * Prints training progress.
     *
     * <p>Every event other than {@code LEARNING_STOPPED} prints the current iteration and total
     * network error. {@code LEARNING_STOPPED} prints a summary with the formatted final error.
     *
     * @param learningEvent event whose source is assumed to be a {@link BackPropagation} rule,
     *     as registered by {@link #train()}
     */
    @Override // org.neuroph.core.events.LearningEventListener
    public void handleLearningEvent(LearningEvent learningEvent) {
        BackPropagation backPropagation = (BackPropagation) learningEvent.getSource();
        if (learningEvent.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
            System.out.println("Iteration: " + backPropagation.getCurrentIteration() + " | Network error: " + backPropagation.getTotalNetworkError());
            return;
        }
        double totalNetworkError = backPropagation.getTotalNetworkError();
        System.out.println("Training completed in " + backPropagation.getCurrentIteration() + " iterations, ");
        System.out.println("With total error: " + formatDecimalNumber(totalNetworkError));
    }
}
