package org.ea.javacnn;

import java.util.ArrayList;
import java.util.Arrays;
import org.ea.javacnn.data.DataBlock;
import org.ea.javacnn.data.OutputDefinition;
import org.ea.javacnn.layers.ConvolutionLayer;
import org.ea.javacnn.layers.DropoutLayer;
import org.ea.javacnn.layers.FullyConnectedLayer;
import org.ea.javacnn.layers.InputLayer;
import org.ea.javacnn.layers.LocalResponseNormalizationLayer;
import org.ea.javacnn.layers.PoolingLayer;
import org.ea.javacnn.layers.RectifiedLinearUnitsLayer;
import org.ea.javacnn.losslayers.SoftMaxLayer;
import org.ea.javacnn.readers.PGMReader;
import org.ea.javacnn.trainers.AdaGradTrainer;

/**
 * Command-line driver that trains a small convolutional network on the PGM images under
 * {@code pgmfiles/train} and, after every pass over the training set, evaluates it on the images
 * under {@code pgmfiles/test}, printing per-class and overall prediction accuracy.
 */
public class MnistTest {

    /** Number of full passes over the training set. */
    private static final int EPOCHS = 500;

    /** A training progress report is printed every this many images within a pass. */
    private static final int REPORT_INTERVAL = 1000;

    /**
     * Builds the network, then alternates training passes with test-set evaluation.
     *
     * @param args ignored
     * @throws IllegalStateException wrapping any exception raised while training or testing
     */
    public static void main(String[] args) {
        // NOTE(review): raw list because the layer element type is declared outside this file;
        // parameterize it once that type is confirmed against the JavaCNN constructor.
        ArrayList layers = new ArrayList();
        OutputDefinition def = new OutputDefinition();
        PGMReader trainReader = new PGMReader("pgmfiles/train");
        // Layer arguments are passed through unchanged; see each layer class for their meaning.
        layers.add(new InputLayer(def, trainReader.getSizeX(), trainReader.getSizeY(), 1));
        layers.add(new ConvolutionLayer(def, 5, 32, 1, 2));
        layers.add(new RectifiedLinearUnitsLayer());
        layers.add(new PoolingLayer(def, 2, 2, 0));
        layers.add(new ConvolutionLayer(def, 5, 64, 1, 2));
        layers.add(new RectifiedLinearUnitsLayer());
        layers.add(new PoolingLayer(def, 2, 2, 0));
        layers.add(new FullyConnectedLayer(def, 1024));
        layers.add(new LocalResponseNormalizationLayer());
        layers.add(new DropoutLayer(def));
        layers.add(new FullyConnectedLayer(def, trainReader.numOfClasses()));
        layers.add(new SoftMaxLayer(def));

        JavaCNN net = new JavaCNN(layers);
        AdaGradTrainer trainer = new AdaGradTrainer(net, 20, 0.001f);
        PGMReader testReader = new PGMReader("pgmfiles/test");
        try {
            long startMillis = System.currentTimeMillis();
            int numClasses = testReader.numOfClasses();
            // Sized by the class count because printPredictions reads numClasses entries of each.
            int[] totalPerClass = new int[numClasses];
            int[] correctPerClass = new int[numClasses];
            // A single DataBlock is reused for every training and test image.
            DataBlock input = new DataBlock(trainReader.getSizeX(), trainReader.getSizeY(), 1, 0.0d);

            for (int epoch = 1; epoch <= EPOCHS; epoch++) {
                double lossSum = 0.0d;
                for (int n = 0; n < trainReader.size(); n++) {
                    input.addImageData(trainReader.readNextImage(), trainReader.getMaxvalue());
                    lossSum += trainer.train(input, trainReader.readNextLabel()).getLoss();
                    if (n != 0 && n % REPORT_INTERVAL == 0) {
                        System.out.println("Pass " + epoch + " Read images: " + n);
                        System.out.println("Training time: " + (System.currentTimeMillis() - startMillis));
                        // Images 0..n (n + 1 of them) have contributed to lossSum so far.
                        System.out.println("Loss: " + (lossSum / (n + 1)));
                        startMillis = System.currentTimeMillis();
                    }
                }
                // Average over the actual training-set size rather than a hard-coded 60000.
                System.out.println("Loss: " + (lossSum / trainReader.size()));
                trainReader.reset();

                // At this point the counters still hold the previous pass's evaluation.
                if (epoch != 1) {
                    System.out.println("Last run:");
                    System.out.println("=================================");
                    printPredictions(correctPerClass, totalPerClass, testReader.size(), numClasses);
                }

                long testStartMillis = System.currentTimeMillis();
                Arrays.fill(correctPerClass, 0);
                Arrays.fill(totalPerClass, 0);
                for (int n = 0; n < testReader.size(); n++) {
                    // Use the test reader's own max value for test images, not the training reader's.
                    input.addImageData(testReader.readNextImage(), testReader.getMaxvalue());
                    net.forward(input, false);
                    int label = testReader.readNextLabel();
                    if (label == net.getPrediction()) {
                        correctPerClass[label]++;
                    }
                    totalPerClass[label]++;
                }
                testReader.reset();
                System.out.println("Testing time: " + (System.currentTimeMillis() - testStartMillis));
                System.out.println("Current run:");
                System.out.println("=================================");
                printPredictions(correctPerClass, totalPerClass, testReader.size(), numClasses);
                startMillis = System.currentTimeMillis();
            }
        } catch (Exception e) {
            // Fail loudly (non-zero exit, cause preserved) instead of printing and exiting normally.
            throw new IllegalStateException("MNIST training/evaluation failed", e);
        }
    }

    /**
     * Prints per-class and overall accuracy for one evaluation run.
     *
     * @param correct    number of correct predictions per class
     * @param total      number of test images per class
     * @param testSize   total number of test images, used for the overall accuracy
     * @param numClasses number of leading entries of {@code correct}/{@code total} to report
     */
    private static void printPredictions(int[] correct, int[] total, int testSize, int numClasses) {
        int totalCorrect = 0;
        for (int label = 0; label < numClasses; label++) {
            System.out.println("Number " + label + " has predictions " + correct[label] + "/" + total[label]
                    + "\t\t" + percent(correct[label], total[label]) + "%");
            totalCorrect += correct[label];
        }
        System.out.println("Total correct predictions " + totalCorrect + "/" + testSize
                + "\t\t" + percent(totalCorrect, testSize) + "%");
    }

    /**
     * Returns {@code part / whole} as a percentage. The division is done in floating point so the
     * result is not truncated to 0 or 100; returns 0 when {@code whole} is 0 (e.g. a class with no
     * test images) instead of throwing.
     */
    private static float percent(int part, int whole) {
        return whole == 0 ? 0.0f : 100.0f * part / whole;
    }
}
