package org.neuroph.samples;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Sample that trains a multi-layer perceptron on the normalised Iris data set, saves the
 * trained network to {@code irisNet.nnet}, and then prints the network's output for every
 * row of the data set.
 */
public class IrisClassificationSample {

    /**
     * Learning listener that prints the current iteration and the total network error each
     * time the learning rule fires a {@link LearningEvent}.
     *
     * <p>The event source is cast to {@link BackPropagation}. This listener is attached in
     * {@link #main} to the perceptron's learning rule, which is expected to be a
     * {@code BackPropagation} (or subclass). Any other source would raise a
     * {@code ClassCastException}.
     */
    static class LearningListener implements LearningEventListener {
        LearningListener() {
        }

        @Override // org.neuroph.core.events.LearningEventListener
        public void handleLearningEvent(LearningEvent learningEvent) {
            BackPropagation backPropagation = (BackPropagation) learningEvent.getSource();
            System.out.println("Current iteration: " + backPropagation.getCurrentIteration());
            System.out.println("Error: " + backPropagation.getTotalNetworkError());
        }
    }

    /**
     * Trains the network, saves it, and runs it over the data set.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        // Layer sizes 4 / 16 / 3: the input and output counts must match the 4-input,
        // 3-output column split passed to createFromFile below.
        MultiLayerPerceptron multiLayerPerceptron = new MultiLayerPerceptron(4, 16, 3);
        DataSet createFromFile = DataSet.createFromFile("data_sets/iris_data_normalised.txt", 4, 3, ",");

        multiLayerPerceptron.getLearningRule().addListener(new LearningListener());
        multiLayerPerceptron.getLearningRule().setLearningRate(0.5d);
        // Training stops at whichever limit is hit first: the error threshold or the iteration cap.
        multiLayerPerceptron.getLearningRule().setMaxError(0.01d);
        multiLayerPerceptron.getLearningRule().setMaxIterations(30000);
        multiLayerPerceptron.learn(createFromFile);
        multiLayerPerceptron.save("irisNet.nnet");
        System.out.println("Done training.");

        System.out.println("Testing network...");
        // Previously the testing step was announced but never executed.
        // Note: this evaluates on the same rows used for training.
        testNeuralNetwork(multiLayerPerceptron, createFromFile);
    }

    /**
     * Feeds every row of {@code dataSet} through {@code neuralNetwork} and prints each input
     * vector alongside the network's computed output.
     *
     * @param neuralNetwork the network to evaluate
     * @param dataSet       rows whose inputs are fed to the network
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        for (DataSetRow dataSetRow : dataSet.getRows()) {
            double[] input = dataSetRow.getInput();
            neuralNetwork.setInput(input);
            neuralNetwork.calculate();
            double[] output = neuralNetwork.getOutput();
            System.out.print("Input: " + Arrays.toString(input));
            System.out.println(" Output: " + Arrays.toString(output));
        }
    }
}
