/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ResilientPropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * Sample that trains a sigmoid multilayer perceptron (2 inputs, 3 hidden
 * neurons, 1 output) on the XOR truth table using resilient propagation.
 *
 * <p>While training, this instance acts as a {@link LearningEventListener} and
 * prints the current iteration and total network error for every learning
 * event except {@code LEARNING_STOPPED}. After training, the network output is
 * printed for each row of the training set.
 */
public class XorResilientPropagationSample
implements LearningEventListener {

    /**
     * Entry point: runs the sample.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        new XorResilientPropagationSample().run();
    }

    /**
     * Builds the four-row XOR training set, trains the network with
     * {@link ResilientPropagation} and prints the number of iterations used,
     * followed by the network output for each training row.
     */
    public void run() {
        DataSet trainingSet = new DataSet(2, 1);
        trainingSet.addRow(new DataSetRow(new double[]{0.0, 0.0}, new double[]{0.0}));
        trainingSet.addRow(new DataSetRow(new double[]{0.0, 1.0}, new double[]{1.0}));
        trainingSet.addRow(new DataSetRow(new double[]{1.0, 0.0}, new double[]{1.0}));
        trainingSet.addRow(new DataSetRow(new double[]{1.0, 1.0}, new double[]{0.0}));

        MultiLayerPerceptron myMlPerceptron = new MultiLayerPerceptron(TransferFunctionType.SIGMOID, 2, 3, 1);

        // Keep a typed reference to the rule we install so that neither
        // registering the listener nor reading the iteration count needs a cast.
        ResilientPropagation learningRule = new ResilientPropagation();
        myMlPerceptron.setLearningRule(learningRule);
        learningRule.addListener(this);

        System.out.println("Training neural network...");
        myMlPerceptron.learn(trainingSet);

        int iterations = learningRule.getCurrentIteration();
        System.out.println("Learned in " + iterations + " iterations");
        System.out.println("Testing trained neural network");
        testNeuralNetwork(myMlPerceptron, trainingSet);
    }

    /**
     * Feeds every row of {@code testSet} through {@code neuralNet} and prints
     * each input vector alongside the resulting network output.
     *
     * @param neuralNet the network to evaluate
     * @param testSet   rows whose inputs are fed to the network
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
        for (DataSetRow testSetRow : testSet.getRows()) {
            double[] input = testSetRow.getInput();
            neuralNet.setInput(input);
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();
            System.out.print("Input: " + Arrays.toString(input));
            System.out.println(" Output: " + Arrays.toString(networkOutput));
        }
    }

    /**
     * Prints the current iteration number and total network error for every
     * learning event except {@code LEARNING_STOPPED}.
     *
     * <p>Events whose source is not a {@link BackPropagation} rule are ignored
     * instead of failing with a {@link ClassCastException}. This can happen if
     * the listener is registered on a different kind of learning rule.
     *
     * @param event the learning event raised by the learning rule
     */
    @Override
    public void handleLearningEvent(LearningEvent event) {
        if (event.getEventType() == LearningEvent.Type.LEARNING_STOPPED) {
            return;
        }
        Object source = event.getSource();
        if (!(source instanceof BackPropagation)) {
            return;
        }
        BackPropagation bp = (BackPropagation) source;
        System.out.println(bp.getCurrentIteration() + ". iteration : " + bp.getTotalNetworkError());
    }
}

