/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.convolution.mnist;

import java.io.IOException;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.Evaluation;
import org.neuroph.nnet.ConvolutionalNetwork;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.ConvolutionalBackpropagation;
import org.neuroph.samples.convolution.mnist.MNISTDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Sample that trains a {@link ConvolutionalNetwork} on MNIST data files, evaluates it and saves it.
 *
 * <p>The training and test sets are read from the {@code data_sets/} directory. The trained network
 * is evaluated on the test set with {@link Evaluation#runFullEvaluation} and written to
 * {@code mnist.nnet}.
 */
public class MNISTExample {
    private static final Logger LOG = LoggerFactory.getLogger(MNISTExample.class);

    /**
     * Loads the data, builds and trains the network, evaluates it on the test set and saves it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            // NOTE(review): the trailing int presumably limits how many samples are loaded from
            // each file — confirm against MNISTDataSet.createFromFile.
            DataSet trainSet = MNISTDataSet.createFromFile(
                    "data_sets/train-labels.idx1-ubyte", "data_sets/train-images.idx3-ubyte", 60);
            DataSet testSet = MNISTDataSet.createFromFile(
                    "data_sets/t10k-labels.idx1-ubyte", "data_sets/t10k-images.idx3-ubyte", 10);

            // Assumes MNISTDataSet yields 32x32 single-channel inputs and 10-element outputs
            // — TODO confirm.
            ConvolutionalNetwork convolutionNetwork = new ConvolutionalNetwork.Builder()
                    .withInputLayer(32, 32, 1)
                    .withConvolutionLayer(5, 5, 6)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, 16)
                    .withPoolingLayer(2, 2)
                    .withConvolutionLayer(5, 5, 120)
                    .withFullConnectedLayer(84)
                    .withFullConnectedLayer(10)
                    .build();

            ConvolutionalBackpropagation backPropagation = new ConvolutionalBackpropagation();
            backPropagation.setLearningRate(0.001);
            backPropagation.setMaxError(0.01);
            backPropagation.setErrorFunction(new MeanSquaredError());
            convolutionNetwork.setLearningRule(backPropagation);
            backPropagation.addListener(new LearningListener());

            convolutionNetwork.learn(trainSet);
            Evaluation.runFullEvaluation(convolutionNetwork, testSet);
            convolutionNetwork.save("mnist.nnet");
        } catch (IOException e) {
            // Route the failure through the class logger (not printStackTrace) so it lands next to
            // the training progress output; passing e keeps the full stack trace.
            LOG.error("MNIST example failed with an I/O error", e);
        }
    }

    /**
     * Learning listener that logs the current iteration, the total network error and the
     * wall-clock seconds elapsed since the previous event (or since construction, for the first
     * event).
     */
    static class LearningListener
    implements LearningEventListener {
        /** Time in ms of the previous event, or of construction before the first event. */
        long start = System.currentTimeMillis();

        LearningListener() {
        }

        @Override
        public void handleLearningEvent(LearningEvent event) {
            // Read the clock once: the same instant ends this interval and starts the next, so the
            // time spent in the logging below is not dropped from the measurements.
            long now = System.currentTimeMillis();
            // The listener is registered on a BackPropagation rule in main, which exposes the
            // iteration and error values logged here.
            BackPropagation bp = (BackPropagation) event.getSource();
            double elapsedSeconds = (double) (now - this.start) / 1000.0;
            LOG.info("Current iteration: {}", bp.getCurrentIteration());
            LOG.info("Error: {}", bp.getTotalNetworkError());
            LOG.info("Calculation time: {}", elapsedSeconds);
            this.start = now;
        }
    }
}

