package org.neuroph.samples.uci;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;

/**
 * Sample that trains a multi-layer perceptron on the shuttle landing control data set, prints the
 * trained network's output for every training row, and saves the network to a file.
 *
 * <p>Training progress is printed to standard output via {@link #handleLearningEvent}, since this
 * instance registers itself as a listener on the network's learning rule.
 */
public class ShuttleLandingControlSample implements LearningEventListener {

    /** Location of the training data file. */
    private static final String TRAINING_DATA_FILE = "data_sets/shuttle_landing_control_data.txt";

    /** File the trained network is saved to. */
    private static final String NETWORK_FILE = "MyNeuralNetShuttle.nnet";

    /** Number of input values per row / input-layer neurons. */
    private static final int INPUT_COUNT = 15;

    /** Number of neurons in the single hidden layer. */
    private static final int HIDDEN_NEURON_COUNT = 16;

    /** Number of output values per row / output-layer neurons. */
    private static final int OUTPUT_COUNT = 2;

    private static final double LEARNING_RATE = 0.2d;
    private static final double MAX_ERROR = 0.01d;

    /** Entry point: runs the sample end to end. */
    public static void main(String[] args) {
        ShuttleLandingControlSample sample = new ShuttleLandingControlSample();
        sample.run();
    }

    /**
     * Loads the training data, builds and trains the network, prints its output on the training
     * rows, and saves it to {@value #NETWORK_FILE}.
     */
    public void run() {
        System.out.println("Creating training set...");
        DataSet trainingSet =
                DataSet.createFromFile(TRAINING_DATA_FILE, INPUT_COUNT, OUTPUT_COUNT, ",", false);

        System.out.println("Creating neural network...");
        MultiLayerPerceptron network =
                new MultiLayerPerceptron(INPUT_COUNT, HIDDEN_NEURON_COUNT, OUTPUT_COUNT);
        configureLearningRule(network);

        System.out.println("Training network...");
        network.learn(trainingSet);
        System.out.println("Training completed.");

        System.out.println("Testing network...");
        testNeuralNetwork(network, trainingSet);

        System.out.println("Saving network");
        network.save(NETWORK_FILE);
        System.out.println("Done.");
    }

    /**
     * Registers this sample as a progress listener on the network's learning rule and sets the
     * learning rate and target error.
     */
    private void configureLearningRule(MultiLayerPerceptron network) {
        // NOTE(review): assumes the perceptron's default learning rule is MomentumBackpropagation;
        // the cast throws ClassCastException otherwise.
        MomentumBackpropagation rule = (MomentumBackpropagation) network.getLearningRule();
        rule.addListener(this);
        rule.setLearningRate(LEARNING_RATE);
        rule.setMaxError(MAX_ERROR);
    }

    /**
     * Runs every row of {@code dataSet} through {@code neuralNetwork} and prints the row's input
     * next to the computed network output.
     *
     * @param neuralNetwork the network to evaluate
     * @param dataSet the rows to feed through the network
     */
    public void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        for (DataSetRow row : dataSet.getRows()) {
            double[] input = row.getInput();
            neuralNetwork.setInput(input);
            neuralNetwork.calculate();
            double[] networkOutput = neuralNetwork.getOutput();
            System.out.println(
                    "Input: " + Arrays.toString(input) + " Output: " + Arrays.toString(networkOutput));
        }
    }

    /**
     * Prints the current iteration number and total network error each time the learning rule
     * reports an event.
     *
     * @param event the learning event; its source is expected to be a {@link BackPropagation} rule
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation source = (BackPropagation) event.getSource();
        String progress = source.getCurrentIteration()
                + ". iteration | Total network error: "
                + source.getTotalNetworkError();
        System.out.println(progress);
    }
}
