package org.neuroph.contrib.samples.timeseries;

import java.util.Arrays;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * Sample that trains a {@link MultiLayerPerceptron} on a tab-delimited time-series file and then
 * prints the network's output for every row of the training set.
 *
 * <p>Each row of the data file is read as {@value #INPUT_COUNT} input values followed by
 * {@value #OUTPUT_COUNT} output value, separated by tabs, with no header row.
 *
 * <p>Usage: {@code java TestTimeSeries [dataFile]}. When no path is given,
 * {@link #DEFAULT_DATA_FILE} is used.
 */
public class TestTimeSeries implements LearningEventListener {

    /** Data file used when the caller does not supply a path. */
    static final String DEFAULT_DATA_FILE = "C:\\timeseries\\BSW15";

    /** Input columns per data row; also the size of the network's input layer. */
    static final int INPUT_COUNT = 5;

    /** Neurons in the single hidden layer. */
    static final int HIDDEN_COUNT = 10;

    /** Output columns per data row; also the size of the network's output layer. */
    static final int OUTPUT_COUNT = 1;

    /** Column separator used in the data file. */
    private static final String DELIMITER = "\t";

    /** Network built by {@link #train(String)}; {@code null} until training has run. */
    NeuralNetwork<?> neuralNet;

    /** Data set loaded by {@link #train(String)}; {@code null} until training has run. */
    DataSet trainingSet;

    /**
     * Trains the network and prints its output for each training row.
     *
     * @param args optional; {@code args[0]} is the path of the data file to train on
     */
    public static void main(String[] args) {
        String dataFile = args.length > 0 ? args[0] : DEFAULT_DATA_FILE;
        TestTimeSeries testTimeSeries = new TestTimeSeries();
        testTimeSeries.train(dataFile);
        testTimeSeries.testNeuralNetwork();
    }

    /** Trains on {@link #DEFAULT_DATA_FILE}. Kept for callers that relied on the original API. */
    public void train() {
        train(DEFAULT_DATA_FILE);
    }

    /**
     * Builds a TANH multi-layer perceptron, loads {@code dataFile}, and trains the network with
     * momentum backpropagation (learning rate 0.2, momentum 0.5). Progress is reported through
     * {@link #handleLearningEvent(LearningEvent)}.
     *
     * <p>NOTE(review): no iteration cap is set here, so training ends only when the learning
     * rule's own stopping criteria are met. Confirm the defaults are acceptable for the data.
     *
     * @param dataFile path of the tab-delimited data file
     * @throws IllegalArgumentException if {@code dataFile} is {@code null} or empty
     */
    public void train(String dataFile) {
        if (dataFile == null || dataFile.isEmpty()) {
            throw new IllegalArgumentException("dataFile must be a non-empty path");
        }
        this.neuralNet = new MultiLayerPerceptron(
                TransferFunctionType.TANH, INPUT_COUNT, HIDDEN_COUNT, OUTPUT_COUNT);

        // MultiLayerPerceptron is expected to come with a MomentumBackpropagation learning rule;
        // the cast fails fast with ClassCastException if that ever changes.
        MomentumBackpropagation learningRule =
                (MomentumBackpropagation) this.neuralNet.getLearningRule();
        learningRule.setLearningRate(0.2d);
        learningRule.setMomentum(0.5d);
        learningRule.addListener(this);

        // The layer sizes above and the column counts here must agree, hence the shared constants.
        this.trainingSet =
                DataSet.createFromFile(dataFile, INPUT_COUNT, OUTPUT_COUNT, DELIMITER, false);
        this.neuralNet.learn(this.trainingSet);
        System.out.println("Done training.");
    }

    /**
     * Feeds every training row through the trained network and prints the input alongside the
     * network's output.
     *
     * @throws IllegalStateException if {@link #train()} or {@link #train(String)} has not run yet
     */
    public void testNeuralNetwork() {
        if (this.neuralNet == null || this.trainingSet == null) {
            throw new IllegalStateException("train() must be called before testNeuralNetwork()");
        }
        System.out.println("Testing network...");
        for (DataSetRow dataSetRow : this.trainingSet.getRows()) {
            double[] input = dataSetRow.getInput();
            this.neuralNet.setInput(input);
            this.neuralNet.calculate();
            double[] output = this.neuralNet.getOutput();
            System.out.print("Input: " + Arrays.toString(input));
            System.out.println(" Output: " + Arrays.toString(output));
        }
    }

    /**
     * Prints the current iteration and total network error for each learning event.
     *
     * @param learningEvent event whose source is the learning rule this object was registered
     *     with in {@link #train(String)}
     */
    @Override
    public void handleLearningEvent(LearningEvent learningEvent) {
        SupervisedLearning supervisedLearning = (SupervisedLearning) learningEvent.getSource();
        System.out.println("Training, Network Epoch " + supervisedLearning.getCurrentIteration()
                + ", Error:" + supervisedLearning.getTotalNetworkError());
    }
}
