package org.neuroph.samples.mnist.learn;

import java.io.IOException;
import java.nio.file.Paths;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.comp.Dimension2D;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.neuroph.samples.convolution.mnist.MNISTDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a convolutional network on MNIST training data and then runs a full evaluation on MNIST
 * test data.
 *
 * <p>Usage: {@code CnnMNIST <ignored> <ignored> <ignored> <conv1> <conv2> <conv3>}. Only the
 * fourth to sixth arguments are read. Each is passed as the third argument of the corresponding
 * {@code withConvolutionLayer} call (presumably the feature-map count of that layer — confirm
 * against the Neuroph builder). The first three arguments are ignored but still required, so
 * existing launch configurations keep working.
 *
 * <p>The directory holding the MNIST files comes from the {@code mnist.dir} system property. It
 * defaults to the location this sample was originally hard-coded to.
 */
public class CnnMNIST {
    private static final Logger LOG = LoggerFactory.getLogger(CnnMNIST.class);

    /** System property that overrides {@link #DEFAULT_DATA_DIR}. */
    private static final String DATA_DIR_PROPERTY = "mnist.dir";

    /** Original hard-coded data location, kept as the default for backward compatibility. */
    private static final String DEFAULT_DATA_DIR =
            "C:\\Users\\jecak_000\\Documents\\Neuroph\\neuroph_novaVerzija\\neurophNoviPull\\neuroph-2.9\\Samples\\";

    /** Index of the first convolution-layer size in the argument array. */
    private static final int FIRST_LAYER_ARG = 3;
    /** Minimum argument count: three ignored arguments plus three layer sizes. */
    private static final int REQUIRED_ARGS = FIRST_LAYER_ARG + 3;

    /**
     * Third argument to {@code MNISTDataSet.createFromFile}. NOTE(review): presumably the number of
     * samples read from each file — confirm against MNISTDataSet.
     */
    private static final int SAMPLE_COUNT = 100;

    private static final int INPUT_SIZE = 32;
    private static final int KERNEL_SIZE = 5;
    private static final int POOL_SIZE = 2;
    private static final int OUTPUT_CLASSES = 10;

    private static final double LEARNING_RATE = 0.2d;
    private static final double MAX_ERROR = 0.01d;
    private static final int MAX_ITERATIONS = 10000;

    /** Logs the iteration number, the total network error and the elapsed time after each epoch. */
    private static class LearningListener implements LearningEventListener {
        /** {@link System#nanoTime()} when the epoch currently being timed started. */
        private long epochStart = System.nanoTime();

        @Override
        public void handleLearningEvent(LearningEvent learningEvent) {
            BackPropagation backPropagation = (BackPropagation) learningEvent.getSource();
            LOG.info("Epoch no#: [{}]. Error [{}]",
                    backPropagation.getCurrentIteration(), backPropagation.getTotalNetworkError());
            // nanoTime is monotonic, so the measurement is immune to wall-clock adjustments.
            long now = System.nanoTime();
            LOG.info("Epoch execution time: {} sec", (now - epochStart) / 1e9d);
            epochStart = now;
        }
    }

    /**
     * Entry point: parses the layer sizes, loads the data, trains the network and evaluates it.
     *
     * @param strArr command-line arguments; see the class documentation for the expected layout
     */
    public static void main(String[] strArr) {
        if (strArr == null || strArr.length < REQUIRED_ARGS) {
            LOG.error("Usage: CnnMNIST <ignored> <ignored> <ignored> <conv1> <conv2> <conv3>");
            return;
        }

        int conv1;
        int conv2;
        int conv3;
        try {
            conv1 = Integer.parseInt(strArr[FIRST_LAYER_ARG]);
            conv2 = Integer.parseInt(strArr[FIRST_LAYER_ARG + 1]);
            conv3 = Integer.parseInt(strArr[FIRST_LAYER_ARG + 2]);
        } catch (NumberFormatException e) {
            LOG.error("Convolution layer sizes must be integers", e);
            return;
        }
        if (conv1 <= 0 || conv2 <= 0 || conv3 <= 0) {
            LOG.error("Convolution layer sizes must be positive, got {}-{}-{}",
                    new Object[] {conv1, conv2, conv3});
            return;
        }
        LOG.info("{}-{}-{}", new Object[] {conv1, conv2, conv3});

        String dataDir = System.getProperty(DATA_DIR_PROPERTY, DEFAULT_DATA_DIR);
        try {
            DataSet trainSet = loadDataSet(dataDir,
                    MNISTDataSet.TRAIN_LABEL_NAME, MNISTDataSet.TRAIN_IMAGE_NAME);
            DataSet testSet = loadDataSet(dataDir,
                    MNISTDataSet.TEST_LABEL_NAME, MNISTDataSet.TEST_IMAGE_NAME);

            ConvolutionalNetwork network = buildNetwork(conv1, conv2, conv3);
            network.setLearningRule(createLearningRule());
            network.learn(trainSet);
            Evaluation.runFullEvaluation(network, testSet);
        } catch (IOException e) {
            LOG.error("Failed to load MNIST data from " + dataDir, e);
        }
    }

    /** Loads one label/image file pair from {@code dataDir} via MNISTDataSet. */
    private static DataSet loadDataSet(String dataDir, String labelFile, String imageFile)
            throws IOException {
        String labelPath = Paths.get(dataDir, labelFile).toString();
        String imagePath = Paths.get(dataDir, imageFile).toString();
        return MNISTDataSet.createFromFile(labelPath, imagePath, SAMPLE_COUNT);
    }

    /**
     * Builds the network from a 32x32x1 input layer. It has three 5x5 convolution layers, the first
     * two each followed by a 2x2 pooling layer, then a 10-unit fully connected output layer.
     */
    private static ConvolutionalNetwork buildNetwork(int conv1, int conv2, int conv3) {
        return new ConvolutionalNetwork.Builder()
                .withInputLayer(INPUT_SIZE, INPUT_SIZE, 1)
                .withConvolutionLayer(KERNEL_SIZE, KERNEL_SIZE, conv1)
                .withPoolingLayer(POOL_SIZE, POOL_SIZE)
                .withConvolutionLayer(KERNEL_SIZE, KERNEL_SIZE, conv2)
                .withPoolingLayer(POOL_SIZE, POOL_SIZE)
                .withConvolutionLayer(KERNEL_SIZE, KERNEL_SIZE, conv3)
                .withFullConnectedLayer(OUTPUT_CLASSES)
                .build();
    }

    /** Creates the back-propagation rule with MSE error, the stop criteria and epoch logging. */
    private static ConvolutionalBackpropagation createLearningRule() {
        ConvolutionalBackpropagation rule = new ConvolutionalBackpropagation();
        rule.setLearningRate(LEARNING_RATE);
        rule.setMaxError(MAX_ERROR);
        rule.setMaxIterations(MAX_ITERATIONS);
        rule.addListener(new LearningListener());
        rule.setErrorFunction(new MeanSquaredError());
        return rule;
    }
}
