package org.neuroph.samples.adalineDigits;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Sample that trains a {@link MultiLayerPerceptron} on the digit images in {@code Data.DIGITS}
 * (one output neuron per image, one-hot targets) and then prints every training image together
 * with the index of the strongest network output.
 *
 * <p>Originally loaded from {@code org/neuroph/samples/adalineDigits/DigitsRecognition.class}.
 */
public class DigitsRecognition {

    /**
     * Number of network inputs. Used both for the network's input layer and for the training
     * {@link DataSet}, so the two cannot drift apart. Presumably equals the length of the vector
     * returned by {@code Data.convertImageIntoData(...).getInput()} — confirm against {@code Data}.
     */
    private static final int INPUT_COUNT = 35;

    /** Number of neurons in the single hidden layer. */
    private static final int HIDDEN_NEURON_COUNT = 19;

    /** Back-propagation learning rate. */
    private static final double LEARNING_RATE = 0.5d;

    /** Training stops once the total network error falls to this value. */
    private static final double MAX_ERROR = 0.001d;

    /** Upper bound on training iterations. */
    private static final int MAX_ITERATIONS = 5000;

    /**
     * Builds the training set, trains the network (logging progress to stdout) and then runs
     * {@link #testNeuralNetwork} over the same training set.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        DataSet trainingSet = generateTraining();
        MultiLayerPerceptron network =
                new MultiLayerPerceptron(INPUT_COUNT, HIDDEN_NEURON_COUNT, Data.DIGITS.length);

        BackPropagation learningRule = network.getLearningRule();
        learningRule.setLearningRate(LEARNING_RATE);
        learningRule.setMaxError(MAX_ERROR);
        learningRule.setMaxIterations(MAX_ITERATIONS);
        learningRule.addListener(new LearningEventListener() {
            @Override
            public void handleLearningEvent(LearningEvent learningEvent) {
                BackPropagation backPropagation = (BackPropagation) learningEvent.getSource();
                // Enum constants are singletons, so identity comparison is the idiomatic check.
                if (learningEvent.getEventType() == LearningEvent.Type.LEARNING_STOPPED) {
                    System.out.println();
                    System.out.println("Training completed in " + backPropagation.getCurrentIteration() + " iterations");
                    System.out.println("With total error " + backPropagation.getTotalNetworkError() + '\n');
                } else {
                    System.out.println("Iteration: " + backPropagation.getCurrentIteration() + " | Network error: " + backPropagation.getTotalNetworkError());
                }
            }
        });

        network.learn(trainingSet);
        testNeuralNetwork(network, trainingSet);
    }

    /**
     * Feeds every row of {@code dataSet} through {@code neuralNetwork} and prints the row's input
     * rendered as an image, with the index of the winning output appended to the image's last line.
     *
     * @param neuralNetwork the (trained) network to evaluate
     * @param dataSet       rows whose inputs are fed to the network and rendered via
     *                      {@code Data.convertDataIntoImage}
     */
    public static void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        System.out.println("--------------------------------------------------------------------");
        System.out.println("***********************TESTING NEURAL NETWORK***********************");
        for (DataSetRow row : dataSet.getRows()) {
            neuralNetwork.setInput(row.getInput());
            neuralNetwork.calculate();
            int winnerIndex = maxOutput(neuralNetwork.getOutput());

            String[] imageLines = Data.convertDataIntoImage(row.getInput());
            int lastLine = imageLines.length - 1;
            for (int i = 0; i < imageLines.length; i++) {
                if (i == lastLine) {
                    System.out.println(imageLines[i] + "----> " + winnerIndex);
                } else {
                    System.out.println(imageLines[i]);
                }
            }
            System.out.println("");
        }
    }

    /**
     * Creates one training row per entry of {@code Data.DIGITS}: the input is the image converted
     * by {@code Data.convertImageIntoData}, and the desired output is a one-hot vector of length
     * {@code Data.DIGITS.length} with {@code 1.0} at the image's own index.
     *
     * @return the training set ({@value #INPUT_COUNT} inputs, {@code Data.DIGITS.length} outputs)
     */
    public static DataSet generateTraining() {
        int classCount = Data.DIGITS.length;
        DataSet dataSet = new DataSet(INPUT_COUNT, classCount);
        for (int i = 0; i < classCount; i++) {
            double[] input = Data.convertImageIntoData(Data.DIGITS[i]).getInput();
            // A new double[] is already zero-filled; only the target slot needs setting.
            double[] desiredOutput = new double[classCount];
            desiredOutput[i] = 1.0d;
            dataSet.addRow(new DataSetRow(input, desiredOutput));
        }
        return dataSet;
    }

    /**
     * Returns the index of the largest value in {@code outputs}. On ties, the first (lowest)
     * index wins.
     *
     * @param outputs network output vector; must be non-null and non-empty
     * @return index of the maximum element
     * @throws IllegalArgumentException if {@code outputs} is empty
     */
    public static int maxOutput(double[] outputs) {
        if (outputs.length == 0) {
            throw new IllegalArgumentException("outputs must not be empty");
        }
        int best = 0;
        for (int i = 1; i < outputs.length; i++) {
            if (outputs[i] > outputs[best]) {
                best = i;
            }
        }
        return best;
    }
}
