/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.mnist.learn;

import java.io.IOException;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.comp.Dimension2D;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.neuroph.samples.convolution.mnist.MNISTDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Command-line sample that trains a convolutional network on MNIST data files
 * and then runs a full evaluation against a test set.
 *
 * <p>The number of feature maps for the three convolution layers is read from
 * the command-line arguments at positions 3, 4 and 5. Arguments 0-2 are not
 * read by this class.
 *
 * <p>The directory containing the {@code data_sets} folder defaults to
 * {@link #DEFAULT_SAMPLES_DIR} and can be overridden with the
 * {@value #SAMPLES_DIR_PROPERTY} system property.
 */
public class CnnMNIST {
    private static final Logger LOG = LoggerFactory.getLogger(CnnMNIST.class);

    /** System property that overrides the directory holding the {@code data_sets} folder. */
    private static final String SAMPLES_DIR_PROPERTY = "neuroph.samples.dir";

    /** Directory used when {@link #SAMPLES_DIR_PROPERTY} is not set (original hard-coded location). */
    private static final String DEFAULT_SAMPLES_DIR = "C:\\Users\\jecak_000\\Documents\\Neuroph\\neuroph_novaVerzija\\neurophNoviPull\\neuroph-2.9\\Samples\\";

    /**
     * Builds the network, trains it on the training set and evaluates it on the test set.
     *
     * @param args command-line arguments; {@code args[3]}, {@code args[4]} and
     *             {@code args[5]} must be integers giving the third argument of the
     *             first, second and third {@code withConvolutionLayer} call
     * @throws IllegalArgumentException if one of the required arguments is missing
     *                                  or is not a valid integer
     */
    public static void main(String[] args) {
        int maxIter = 10000;
        double maxError = 0.01;
        double learningRate = 0.2;

        // Validate up front so a missing argument yields a readable message
        // instead of a bare ArrayIndexOutOfBoundsException.
        int layer1 = parseLayerArg(args, 3);
        int layer2 = parseLayerArg(args, 4);
        int layer3 = parseLayerArg(args, 5);
        LOG.info("{}-{}-{}", new Object[]{layer1, layer2, layer3});

        String samplesDir = resolveSamplesDir();
        try {
            String labelName = samplesDir.concat("data_sets/train-labels.idx1-ubyte");
            String trainImage = samplesDir.concat("data_sets/train-images.idx3-ubyte");
            String testLabel = samplesDir.concat("data_sets/t10k-labels.idx1-ubyte");
            String testImage = samplesDir.concat("data_sets/t10k-images.idx3-ubyte");

            // NOTE(review): the third argument (100) presumably limits how many
            // samples are loaded - confirm against MNISTDataSet.createFromFile.
            DataSet trainSet = MNISTDataSet.createFromFile(labelName, trainImage, 100);
            DataSet testSet = MNISTDataSet.createFromFile(testLabel, testImage, 100);

            ConvolutionalNetwork convolutionNetwork = new ConvolutionalNetwork.Builder()
                    .withInputLayer(32, 32, 1)
                    .withConvolutionLayer(5, 5, layer1)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, layer2)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, layer3)
                    .withFullConnectedLayer(10)
                    .build();

            ConvolutionalBackpropagation backPropagation = new ConvolutionalBackpropagation();
            backPropagation.setLearningRate(learningRate);
            backPropagation.setMaxError(maxError);
            backPropagation.setMaxIterations(maxIter);
            backPropagation.addListener(new LearningListener(convolutionNetwork, testSet));
            backPropagation.setErrorFunction(new MeanSquaredError());

            convolutionNetwork.setLearningRule(backPropagation);
            convolutionNetwork.learn(trainSet);
            Evaluation.runFullEvaluation(convolutionNetwork, testSet);
        }
        catch (IOException e) {
            // Route the failure through the logger rather than printStackTrace().
            LOG.error("Failed to load MNIST data sets from " + samplesDir, e);
        }
    }

    /**
     * Reads the integer command-line argument at the given position.
     *
     * @param args  the raw command-line arguments
     * @param index position of the argument to parse
     * @return the parsed integer value
     * @throws IllegalArgumentException if the argument is absent or not an integer
     */
    private static int parseLayerArg(String[] args, int index) {
        if (args == null || args.length <= index) {
            int supplied = args == null ? 0 : args.length;
            throw new IllegalArgumentException("Missing integer argument at position " + index
                    + " (expected at least " + (index + 1) + " arguments, got " + supplied + ")");
        }
        try {
            return Integer.parseInt(args[index]);
        }
        catch (NumberFormatException e) {
            throw new IllegalArgumentException("Argument at position " + index
                    + " is not a valid integer: '" + args[index] + "'", e);
        }
    }

    /**
     * Returns the directory prefix that the {@code data_sets/...} relative paths
     * are appended to, guaranteeing that it ends with a path separator.
     *
     * @return the configured directory, or {@link #DEFAULT_SAMPLES_DIR} when the
     *         system property is not set
     */
    private static String resolveSamplesDir() {
        String dir = System.getProperty(SAMPLES_DIR_PROPERTY, DEFAULT_SAMPLES_DIR);
        // The file names are appended by plain concatenation, so make sure a
        // user-supplied directory cannot be glued directly to "data_sets".
        if (!dir.isEmpty() && !dir.endsWith("/") && !dir.endsWith("\\")) {
            dir = dir + "/";
        }
        return dir;
    }

    /**
     * Learning listener that logs the current iteration, the total network error
     * and the wall-clock time elapsed since the previous event.
     */
    private static class LearningListener
    implements LearningEventListener {
        // Stored at construction time; not read by this listener at present.
        private final NeuralNetwork neuralNetwork;
        private final DataSet testSet;
        // Timestamp (ms) of listener creation or of the last handled event.
        private long start = System.currentTimeMillis();

        /**
         * @param neuralNetwork the network being trained
         * @param testSet       the data set associated with this listener
         */
        public LearningListener(NeuralNetwork neuralNetwork, DataSet testSet) {
            this.testSet = testSet;
            this.neuralNetwork = neuralNetwork;
        }

        /**
         * Logs progress for the learning rule that fired the event and restarts the timer.
         *
         * @param event the learning event; its source is cast to {@link BackPropagation}
         */
        @Override
        public void handleLearningEvent(LearningEvent event) {
            BackPropagation bp = (BackPropagation)event.getSource();
            LOG.info("Epoch no#: [{}]. Error [{}]", bp.getCurrentIteration(), bp.getTotalNetworkError());
            long now = System.currentTimeMillis();
            LOG.info("Epoch execution time: {} sec", (double)(now - this.start) / 1000.0);
            this.start = now;
        }
    }
}

