/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.cifar10;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.imgrec.ColorMode;
import org.neuroph.imgrec.ImageRecognitionHelper;
import org.neuroph.imgrec.image.Dimension;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains a small convolutional network on a CIFAR-10 image subset and then evaluates it.
 *
 * <p>Usage: {@code Cifar10Example [trainDir] [testDir]}.
 * <ul>
 *   <li>If {@code trainDir} is omitted, the original hard-coded directory is used.</li>
 *   <li>If {@code testDir} is omitted, the training set is also used for evaluation. A warning is
 *       logged because the reported metrics then measure fit, not generalization.</li>
 * </ul>
 */
public class Cifar10Example {
    private static final Logger LOG = LoggerFactory.getLogger(Cifar10Example.class);

    /** Image directory used when no training directory is given on the command line. */
    private static final String DEFAULT_TRAIN_DIR =
            "D:\\Doktorske\\Beograd\\Neuronske mreze - Zoran Sevarac\\Cifar 10\\train\\train_1000\\";

    /** CIFAR-10 class names, in the order handed to the data set loader. */
    private static final List<String> LABELS = Arrays.asList(
            "airplane", "automobile", "bird", "cat", "deer",
            "dog", "frog", "horse", "ship", "truck");

    /**
     * Entry point: loads the data, builds and trains the network, then runs a full evaluation.
     *
     * @param args optional {@code [trainDir] [testDir]}
     * @throws IOException propagated from data set loading
     */
    public static void main(String[] args) throws IOException {
        int maxIter = 20;
        double maxError = 0.03;
        double learningRate = 0.3;
        // Number of feature maps (kernels) in each of the three convolution layers.
        int layer1 = 10;
        int layer2 = 15;
        int layer3 = 20;
        LOG.info("{}-{}-{}", new Object[]{layer1, layer2, layer3});

        String trainDir = args.length > 0 ? args[0] : DEFAULT_TRAIN_DIR;
        DataSet trainSet = loadDataSet(trainDir);
        DataSet testSet;
        if (args.length > 1) {
            testSet = loadDataSet(args[1]);
        } else {
            // Preserves the original behaviour, but makes it visible: scores on training data
            // are optimistic and should not be read as test accuracy.
            LOG.warn("No test directory given; evaluating on the training set ({})", trainDir);
            testSet = trainSet;
        }

        // 32x32x3 input -> conv5 (28x28) -> pool2 (14x14) -> conv5 (10x10) -> pool2 (5x5)
        // -> conv5 (1x1) -> 10 fully connected outputs, one per label.
        ConvolutionalNetwork convolutionNetwork = new ConvolutionalNetwork.Builder()
                .withInputLayer(32, 32, 3)
                .withConvolutionLayer(5, 5, layer1)
                .withPoolingLayer(2, 2)
                .withConvolutionLayer(5, 5, layer2)
                .withPoolingLayer(2, 2)
                .withConvolutionLayer(5, 5, layer3)
                .withFullConnectedLayer(10)
                .build();

        ConvolutionalBackpropagation backPropagation = new ConvolutionalBackpropagation();
        backPropagation.setLearningRate(learningRate);
        backPropagation.setMaxError(maxError);
        backPropagation.setMaxIterations(maxIter);
        backPropagation.addListener(new LearningListener());
        backPropagation.setErrorFunction(new MeanSquaredError());

        convolutionNetwork.setLearningRule(backPropagation);
        convolutionNetwork.learn(trainSet);
        Evaluation.runFullEvaluation(convolutionNetwork, testSet);
    }

    /** Loads one image directory with the RGB 32x32 settings shared by train and test data. */
    private static DataSet loadDataSet(String imageDir) throws IOException {
        return ImageRecognitionHelper.createImageDataSetFromFile(
                imageDir, LABELS, "", ColorMode.COLOR_RGB, new Dimension(32, 32), "cifar", 1);
    }

    /**
     * Logs the iteration number, total network error and elapsed time for every learning event
     * it receives.
     */
    private static class LearningListener implements LearningEventListener {
        /** Monotonic timestamp (ns) of the previous event, or of construction for the first one. */
        private long start = System.nanoTime();

        @Override
        public void handleLearningEvent(LearningEvent event) {
            // Assumes the listener is only registered on a BackPropagation-derived rule, as in main.
            BackPropagation bp = (BackPropagation) event.getSource();
            LOG.info("Epoch no#: [{}]. Error [{}]", bp.getCurrentIteration(), bp.getTotalNetworkError());
            long now = System.nanoTime();
            // nanoTime is monotonic, so wall-clock adjustments cannot produce negative durations.
            LOG.info("Epoch execution time: {} sec", (now - this.start) / 1_000_000_000.0);
            this.start = now;
        }
    }
}

