package org.neuroph.samples.cifar10;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.imgrec.ColorMode;
import org.neuroph.imgrec.ImageRecognitionHelper;
import org.neuroph.imgrec.image.Dimension;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a small convolutional network on CIFAR-10 images loaded from a directory,
 * then runs a full evaluation of the trained network.
 *
 * <p>Usage: {@code Cifar10Example [trainingImageDir]}. When no directory argument is
 * given, {@link #DEFAULT_TRAINING_DIR} is used.
 */
public class Cifar10Example {
    private static final Logger LOG = LoggerFactory.getLogger(Cifar10Example.class);

    /** Training image directory used when none is passed on the command line (machine-specific). */
    private static final String DEFAULT_TRAINING_DIR =
            "D:\\Doktorske\\Beograd\\Neuronske mreze - Zoran Sevarac\\Cifar 10\\train\\train_1000\\";

    /** The ten CIFAR-10 class labels, in the order handed to the dataset loader. */
    private static final List<String> LABELS = Arrays.asList(
            "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck");

    /** Side length, in pixels, of the square input images. */
    private static final int IMAGE_SIZE = 32;

    /** Third argument of the three convolution layers (presumably the kernel/feature-map count). */
    private static final int CONV1_KERNELS = 10;
    private static final int CONV2_KERNELS = 15;
    private static final int CONV3_KERNELS = 20;

    /** Learning-rule hyperparameters. */
    private static final double LEARNING_RATE = 0.3d;
    private static final double MAX_ERROR = 0.03d;
    private static final int MAX_ITERATIONS = 20;

    /** Logs the iteration number, total network error and elapsed time for each learning event. */
    private static class LearningListener implements LearningEventListener {
        /** {@link System#nanoTime()} of the previous event (or of listener creation). */
        private long lastEventNanos = System.nanoTime();

        @Override // org.neuroph.core.events.LearningEventListener
        public void handleLearningEvent(LearningEvent learningEvent) {
            BackPropagation rule = (BackPropagation) learningEvent.getSource();
            LOG.info("Epoch no#: [{}]. Error [{}]", rule.getCurrentIteration(), rule.getTotalNetworkError());
            // nanoTime is monotonic, unlike currentTimeMillis, so clock adjustments
            // cannot produce bogus (e.g. negative) durations.
            long now = System.nanoTime();
            LOG.info("Epoch execution time: {} sec", (now - lastEventNanos) / 1e9);
            lastEventNanos = now;
        }
    }

    /**
     * Loads the training images, trains the convolutional network and evaluates it.
     *
     * @param args optional; {@code args[0]} overrides the training image directory
     * @throws IOException presumably propagated from the dataset loader when images cannot be read
     */
    public static void main(String[] args) throws IOException {
        String trainingDir = (args != null && args.length > 0) ? args[0] : DEFAULT_TRAINING_DIR;
        // Record the convolution-layer configuration at the top of the run's log.
        LOG.info("{}-{}-{}", new Object[]{CONV1_KERNELS, CONV2_KERNELS, CONV3_KERNELS});

        DataSet trainingSet = ImageRecognitionHelper.createImageDataSetFromFile(
                trainingDir, LABELS, "", ColorMode.COLOR_RGB,
                new Dimension(IMAGE_SIZE, IMAGE_SIZE), "cifar", 1);

        ConvolutionalNetwork network = buildNetwork();
        network.setLearningRule(createLearningRule());
        network.learn(trainingSet);

        // NOTE: this evaluates on the training set itself; no held-out data is loaded,
        // so the reported figures measure fit, not generalization.
        Evaluation.runFullEvaluation(network, trainingSet);
    }

    /**
     * Builds the network: an IMAGE_SIZE x IMAGE_SIZE x 3 input (presumably RGB channels),
     * three convolution layers with (5, 5, n) arguments, 2x2 pooling after the first two,
     * and a fully connected output layer with one neuron per label.
     */
    private static ConvolutionalNetwork buildNetwork() {
        return new ConvolutionalNetwork.Builder()
                .withInputLayer(IMAGE_SIZE, IMAGE_SIZE, 3)
                .withConvolutionLayer(5, 5, CONV1_KERNELS)
                .withPoolingLayer(2, 2)
                .withConvolutionLayer(5, 5, CONV2_KERNELS)
                .withPoolingLayer(2, 2)
                .withConvolutionLayer(5, 5, CONV3_KERNELS)
                .withFullConnectedLayer(LABELS.size())
                .build();
    }

    /** Creates the back-propagation rule with MSE loss and a per-event progress logger. */
    private static ConvolutionalBackpropagation createLearningRule() {
        ConvolutionalBackpropagation rule = new ConvolutionalBackpropagation();
        rule.setLearningRate(LEARNING_RATE);
        rule.setMaxError(MAX_ERROR);
        rule.setMaxIterations(MAX_ITERATIONS);
        rule.addListener(new LearningListener());
        rule.setErrorFunction(new MeanSquaredError());
        return rule;
    }
}
