/*
 * Decompiled with CFR 0.152.
 */
package edu.hitsz.c102c.cnn;

import edu.hitsz.c102c.cnn.CNN;
import edu.hitsz.c102c.cnn.Dataset;
import edu.hitsz.c102c.cnn.Layer;
import edu.hitsz.c102c.util.ConcurenceRunner;
import edu.hitsz.c102c.util.TimedTest;

/**
 * Command-line driver that builds a convolution/sampling network, trains it on
 * {@code dataset/train.format}, saves the model to {@code model/model.cnn}, and writes
 * predictions for {@code dataset/test.format} to {@code dataset/test.predict}.
 */
public class RunCNN {
    /** Side length, in pixels, of the square input layer. */
    private static final int IMAGE_SIDE = 28;

    /**
     * Third argument passed to {@code Dataset.load} for the training file (784 = 28 * 28).
     * NOTE(review): presumably the index of the label column that follows the pixel values;
     * confirm against {@code Dataset.load}.
     */
    private static final int TRAIN_LABEL_INDEX = IMAGE_SIDE * IMAGE_SIDE;

    /** Third argument passed to {@code Dataset.load} for the test file (presumably "no label"). */
    private static final int TEST_LABEL_INDEX = -1;

    /** Second constructor argument of {@link CNN}; presumably the batch size — confirm. */
    private static final int BATCH_SIZE = 50;

    /** Second argument of {@code CNN.train}; presumably the number of training passes — confirm. */
    private static final int TRAIN_ROUNDS = 3;

    private static final String SEPARATOR = ",";
    private static final String TRAIN_FILE = "dataset/train.format";
    private static final String TEST_FILE = "dataset/test.format";
    private static final String MODEL_FILE = "model/model.cnn";
    private static final String PREDICT_FILE = "dataset/test.predict";

    /**
     * Builds the network, trains it, saves the model, and writes predictions for the test set.
     */
    public static void runCnn() {
        CNN cnn = new CNN(buildLayers(), BATCH_SIZE);
        // Training runs in its own method so the training Dataset is unreachable from this
        // frame once it returns, replacing the decompiled "dataset = null" idiom.
        trainAndSave(cnn);
        Dataset testset = Dataset.load(TEST_FILE, SEPARATOR, TEST_LABEL_INDEX);
        cnn.predict(testset, PREDICT_FILE);
    }

    /** Assembles input -> conv(6) -> samp -> conv(12) -> samp -> output(10). */
    private static CNN.LayerBuilder buildLayers() {
        CNN.LayerBuilder builder = new CNN.LayerBuilder();
        builder.addLayer(Layer.buildInputLayer(new Layer.Size(IMAGE_SIDE, IMAGE_SIDE)));
        builder.addLayer(Layer.buildConvLayer(6, new Layer.Size(5, 5)));
        builder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
        builder.addLayer(Layer.buildConvLayer(12, new Layer.Size(5, 5)));
        builder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
        builder.addLayer(Layer.buildOutputLayer(10));
        return builder;
    }

    /** Loads the training file, trains {@code cnn} on it, saves the model, then clears the data. */
    private static void trainAndSave(CNN cnn) {
        Dataset dataset = Dataset.load(TRAIN_FILE, SEPARATOR, TRAIN_LABEL_INDEX);
        cnn.train(dataset, TRAIN_ROUNDS);
        cnn.saveModel(MODEL_FILE);
        dataset.clear();
    }

    /**
     * Runs {@link #runCnn()} once under {@link TimedTest}, then stops {@link ConcurenceRunner}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        try {
            new TimedTest(new TimedTest.TestTask() {

                @Override
                public void process() {
                    RunCNN.runCnn();
                }
            }, 1).test();
        } finally {
            // Always stop the runner, even if training or prediction throws. Otherwise any
            // worker threads it owns (NOTE(review): presumably non-daemon) could keep the JVM alive.
            ConcurenceRunner.stop();
        }
    }
}

