package org.neuroph.samples;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.RBFNetwork;
import org.neuroph.nnet.learning.LMS;
import org.neuroph.nnet.learning.RBFLearning;

/* loaded from: input_file:org/neuroph/samples/RBFClassificationSample.class */
/**
 * Neuroph sample that trains a radial basis function (RBF) network and then
 * prints the trained network's response for every row of the training data.
 *
 * <p>The network is created with the size arguments 1, 15 and 1 (presumably the
 * input, RBF and output neuron counts; confirm against {@code RBFNetwork}). It is
 * trained on {@code data_sets/sine.csv}, which is read as one input column and
 * one output column separated by commas.
 *
 * <p>NOTE(review): despite the class name, a sine data set suggests a
 * curve-fitting (regression) task rather than classification.
 *
 * <p>This class is also its own {@link LearningEventListener}. It prints the
 * iteration number and the total network error each time the learning rule
 * raises an event.
 */
public class RBFClassificationSample implements LearningEventListener {

    /** Relative path of the CSV file used for both training and testing. */
    private static final String DATA_SET_PATH = "data_sets/sine.csv";

    /** Used as the network's first size argument and as the data set's input-column count. */
    private static final int INPUT_COUNT = 1;

    /** Second {@code RBFNetwork} size argument; presumably the number of RBF neurons. */
    private static final int RBF_NEURON_COUNT = 15;

    /** Used as the network's last size argument and as the data set's output-column count. */
    private static final int OUTPUT_COUNT = 1;

    /** Column separator used when parsing the CSV file. */
    private static final String COLUMN_DELIMITER = ",";

    /** Learning rate handed to the RBF learning rule. */
    private static final double LEARNING_RATE = 0.02d;

    /** Error threshold handed to the RBF learning rule via {@code setMaxError}. */
    private static final double MAX_ERROR = 0.01d;

    /**
     * Entry point: builds a sample instance and runs it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        RBFClassificationSample sample = new RBFClassificationSample();
        sample.run();
    }

    /**
     * Builds the network, loads the CSV data set and trains the network on it.
     * It then tests the trained network against the very same rows, printing
     * progress to standard output.
     */
    public void run() {
        RBFNetwork network = new RBFNetwork(INPUT_COUNT, RBF_NEURON_COUNT, OUTPUT_COUNT);
        // NOTE(review): the trailing flag presumably tells the loader whether the first
        // CSV row holds column names; confirm against DataSet.createFromFile.
        DataSet trainingSet = DataSet.createFromFile(
                DATA_SET_PATH, INPUT_COUNT, OUTPUT_COUNT, COLUMN_DELIMITER, false);

        // Fetch the network's learning rule and cast it to RBFLearning so it can be
        // configured. This instance listens for its progress events.
        RBFLearning learningRule = (RBFLearning) network.getLearningRule();
        learningRule.setLearningRate(LEARNING_RATE);
        learningRule.setMaxError(MAX_ERROR);
        learningRule.addListener(this);

        network.learn(trainingSet);

        System.out.println("Done training.");
        System.out.println("Testing network...");
        // Evaluation reuses the training rows; no separate test set is loaded.
        testNeuralNetwork(network, trainingSet);
    }

    /**
     * Feeds every row of {@code dataSet} through {@code neuralNetwork}. For each
     * row it prints the input next to the network's output.
     *
     * @param neuralNetwork network to evaluate
     * @param dataSet       rows whose inputs are fed to the network
     */
    public void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        for (DataSetRow row : dataSet.getRows()) {
            double[] input = row.getInput();
            neuralNetwork.setInput(input);
            neuralNetwork.calculate();
            double[] prediction = neuralNetwork.getOutput();
            System.out.println("Input: " + Arrays.toString(input) + " Output: " + Arrays.toString(prediction));
        }
    }

    /**
     * Prints the current iteration number and the total network error whenever
     * the learning rule raises an event.
     *
     * @param learningEvent event whose source is expected to be an {@link LMS} rule;
     *                      any other source type causes a {@link ClassCastException}
     */
    @Override // org.neuroph.core.events.LearningEventListener
    public void handleLearningEvent(LearningEvent learningEvent) {
        LMS source = (LMS) learningEvent.getSource();
        String progress = source.getCurrentIteration() + ". iteration | Total network error: "
                + source.getTotalNetworkError();
        System.out.println(progress);
    }
}
