package edu.hitsz.c102c.cnn;

import edu.hitsz.c102c.cnn.CNN;
import edu.hitsz.c102c.cnn.Layer;
import edu.hitsz.c102c.util.ConcurenceRunner;
import edu.hitsz.c102c.util.TimedTest;

/* loaded from: input_file:edu/hitsz/c102c/cnn/RunCNN.class */
/**
 * Command-line driver that builds a convolutional network, trains it on the
 * training dataset file, saves the trained model, and writes predictions for
 * the test dataset file.
 *
 * <p>All file locations are relative paths, so they are resolved against the
 * process working directory.
 */
public class RunCNN {
    /** Training data file passed to {@code Dataset.load}. */
    private static final String TRAIN_DATA_PATH = "dataset/train.format";
    /** Test data file passed to {@code Dataset.load}. */
    private static final String TEST_DATA_PATH = "dataset/test.format";
    /** Destination handed to {@code CNN.saveModel}. */
    private static final String MODEL_PATH = "model/model.cnn";
    /** Destination handed to {@code CNN.predict} for the prediction output. */
    private static final String PREDICTION_PATH = "dataset/test.predict";

    /**
     * Assembles the network layer by layer, trains it, persists the model and
     * runs prediction on the test data.
     *
     * <p>Layer order: 28x28 input, 6-map convolution with a 5x5 kernel, 2x2
     * sampling, 12-map convolution with a 5x5 kernel, 2x2 sampling, and an
     * output layer built with 10.
     */
    public static void runCnn() {
        CNN.LayerBuilder layerBuilder = new CNN.LayerBuilder();
        layerBuilder.addLayer(Layer.buildInputLayer(new Layer.Size(28, 28)));
        layerBuilder.addLayer(Layer.buildConvLayer(6, new Layer.Size(5, 5)));
        layerBuilder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
        layerBuilder.addLayer(Layer.buildConvLayer(12, new Layer.Size(5, 5)));
        layerBuilder.addLayer(Layer.buildSampLayer(new Layer.Size(2, 2)));
        layerBuilder.addLayer(Layer.buildOutputLayer(10));

        // NOTE(review): 50 is presumably the mini-batch size -- confirm against CNN's constructor.
        CNN cnn = new CNN(layerBuilder, 50);

        // NOTE(review): 784 (= 28 * 28) is presumably the index of the label column -- confirm
        // against Dataset.load.
        Dataset trainSet = Dataset.load(TRAIN_DATA_PATH, ",", 784);
        // NOTE(review): 3 is presumably the number of training passes -- confirm against CNN.train.
        cnn.train(trainSet, 3);
        cnn.saveModel(MODEL_PATH);

        // The training set is cleared before the test set is loaded, presumably to release
        // memory -- confirm against Dataset.clear.
        trainSet.clear();

        // NOTE(review): -1 presumably marks the test file as having no label column -- confirm.
        Dataset testSet = Dataset.load(TEST_DATA_PATH, ",", -1);
        cnn.predict(testSet, PREDICTION_PATH);
    }

    /**
     * Runs {@link #runCnn()} once inside a {@code TimedTest} and then stops the
     * {@code ConcurenceRunner}.
     *
     * @param strArr command-line arguments; not used
     */
    public static void main(String[] strArr) {
        TimedTest.TestTask task = new TimedTest.TestTask() {
            @Override
            public void process() {
                RunCNN.runCnn();
            }
        };
        try {
            new TimedTest(task, 1).test();
        } finally {
            // Stop the runner even when the timed run throws; otherwise stop() would be skipped
            // and any worker threads it still owns could keep the JVM from exiting
            // (assumption about ConcurenceRunner -- confirm).
            ConcurenceRunner.stop();
        }
    }
}
