package org.neuroph.samples.convolution.mnist;

import java.io.IOException;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/neuroph/samples/convolution/mnist/MNISTExample.class */
/**
 * Sample program that trains a convolutional network on an MNIST training set,
 * evaluates it on an MNIST test set and saves the trained network to
 * {@code mnist.nnet}.
 */
public class MNISTExample {
    private static final Logger LOG = LoggerFactory.getLogger(MNISTExample.class);

    /**
     * Learning-event listener that logs the current iteration, the total network
     * error and the time elapsed since the previous event (or since the listener
     * was created, for the first event).
     */
    static class LearningListener implements LearningEventListener {
        /** Timestamp in milliseconds of listener creation or of the last handled event. */
        long start = System.currentTimeMillis();

        /**
         * Logs training progress for one learning event.
         *
         * @param learningEvent event whose source must be a {@link BackPropagation} rule
         * @throws ClassCastException if the event source is not a {@link BackPropagation}
         */
        @Override
        public void handleLearningEvent(LearningEvent learningEvent) {
            BackPropagation learningRule = (BackPropagation) learningEvent.getSource();

            // Read the clock once so the reported interval and the new baseline
            // agree; two separate reads would leave the time between them unaccounted for.
            long now = System.currentTimeMillis();

            LOG.info("Current iteration: {}", learningRule.getCurrentIteration());
            LOG.info("Error: {}", learningRule.getTotalNetworkError());
            LOG.info("Calculation time: {}", (now - this.start) / 1000.0d);
            this.start = now;
        }
    }

    /**
     * Loads the MNIST data sets, builds and trains the network, runs the full
     * evaluation on the test set and saves the network to {@code mnist.nnet}.
     * An {@link IOException} raised along the way is logged and ends the run.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        try {
            // NOTE(review): the third argument (60 / 10) presumably limits the number
            // of samples loaded -- confirm against MNISTDataSet.createFromFile.
            DataSet trainSet = MNISTDataSet.createFromFile(
                    MNISTDataSet.TRAIN_LABEL_NAME, MNISTDataSet.TRAIN_IMAGE_NAME, 60);
            DataSet testSet = MNISTDataSet.createFromFile(
                    MNISTDataSet.TEST_LABEL_NAME, MNISTDataSet.TEST_IMAGE_NAME, 10);

            ConvolutionalNetwork network = new ConvolutionalNetwork.Builder()
                    .withInputLayer(32, 32, 1)
                    .withConvolutionLayer(5, 5, 6)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, 16)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, 120)
                    .withFullConnectedLayer(84)
                    .withFullConnectedLayer(10)
                    .build();

            ConvolutionalBackpropagation learningRule = new ConvolutionalBackpropagation();
            learningRule.setLearningRate(0.001d);
            learningRule.setMaxError(0.01d);
            learningRule.setErrorFunction(new MeanSquaredError());
            network.setLearningRule(learningRule);
            learningRule.addListener(new LearningListener());

            network.learn(trainSet);
            Evaluation.runFullEvaluation(network, testSet);
            network.save("mnist.nnet");
        } catch (IOException e) {
            // Route the failure through the logger (with the full cause) instead of
            // printing the stack trace straight to stderr.
            LOG.error("MNIST example failed with an I/O error", e);
        }
    }
}
