package org.neuroph.samples;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ResilientPropagation;
import org.neuroph.util.TransferFunctionType;

/* loaded from: input_file:org/neuroph/samples/XorResilientPropagationSample.class */
/**
 * Sample that trains a multilayer perceptron on the XOR truth table using
 * resilient propagation and prints progress and results to standard output.
 *
 * <p>The instance registers itself as a {@link LearningEventListener} on the
 * learning rule so that the total network error is reported while training
 * is in progress.
 */
public class XorResilientPropagationSample implements LearningEventListener {

    /**
     * Program entry point: creates the sample and runs it.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        XorResilientPropagationSample sample = new XorResilientPropagationSample();
        sample.run();
    }

    /**
     * Builds the XOR training set, trains a sigmoid network with layer sizes
     * 2, 3, 1 using resilient propagation, then prints the number of
     * iterations and the network output for every training row.
     */
    public void run() {
        DataSet xorSet = buildXorDataSet();

        MultiLayerPerceptron network =
                new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 2, 3, 1);

        // Keep a handle on the rule so the listener and the iteration count
        // are read from the very object installed on the network.
        ResilientPropagation rule = new ResilientPropagation();
        network.setLearningRule(rule);
        rule.addListener(this);

        System.out.println("Training neural network...");
        network.learn(xorSet);

        int iterations = rule.getCurrentIteration();
        System.out.println("Learned in " + iterations + " iterations");

        System.out.println("Testing trained neural network");
        testNeuralNetwork(network, xorSet);
    }

    /**
     * Creates the four-row XOR data set (two inputs, one output per row).
     *
     * @return data set holding the rows in the order 00, 01, 10, 11
     */
    private static DataSet buildXorDataSet() {
        final double[][] inputs = {
            {0.0d, 0.0d},
            {0.0d, 1.0d},
            {1.0d, 0.0d},
            {1.0d, 1.0d},
        };
        final double[][] targets = {
            {0.0d},
            {1.0d},
            {1.0d},
            {0.0d},
        };

        DataSet rows = new DataSet(2, 1);
        for (int i = 0; i < inputs.length; i++) {
            rows.addRow(new DataSetRow(inputs[i], targets[i]));
        }
        return rows;
    }

    /**
     * Feeds every row of the data set through the network and prints the
     * row's input together with the output the network produced for it.
     *
     * @param neuralNetwork network to evaluate
     * @param dataSet       rows whose inputs are presented to the network
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        for (DataSetRow row : dataSet.getRows()) {
            double[] input = row.getInput();
            neuralNetwork.setInput(input);
            neuralNetwork.calculate();
            double[] result = neuralNetwork.getOutput();

            String inputText = "Input: " + Arrays.toString(row.getInput());
            String outputText = " Output: " + Arrays.toString(result);
            System.out.println(inputText + outputText);
        }
    }

    /**
     * Prints the current iteration number and total network error for every
     * learning event except the one signalling that learning has stopped.
     *
     * @param learningEvent event fired by the learning rule; its source is
     *                      cast to {@link BackPropagation}
     */
    @Override // org.neuroph.core.events.LearningEventListener
    public void handleLearningEvent(LearningEvent learningEvent) {
        // The cast is performed up front, before the event type is inspected.
        BackPropagation source = (BackPropagation) learningEvent.getSource();

        if (learningEvent.getEventType() == LearningEvent.Type.LEARNING_STOPPED) {
            return;
        }

        int iteration = source.getCurrentIteration();
        double totalError = source.getTotalNetworkError();
        System.out.println(iteration + ". iteration : " + totalError);
    }
}
