package org.ea.javacnn.readers;

import java.io.EOFException;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;

/* loaded from: input_file:org/ea/javacnn/readers/MnistReader.class */
/**
 * {@link Reader} implementation that streams labels and images from a pair of
 * IDX-formatted files (a label file and an image file).
 *
 * <p>Both files are opened in the constructor and read sequentially: each call to
 * {@link #readNextLabel()} consumes one byte of the label file and each call to
 * {@link #readNextImage()} consumes {@code sizeX * sizeY} bytes of the image file.
 * {@link #reset()} reopens both files and positions them just after their headers.
 *
 * <p>Failures while opening or parsing the headers are printed to standard error rather
 * than thrown, so a reader constructed from bad files is left partially initialised.
 */
public class MnistReader implements Reader {
    /** Value the first 4-byte header field of the label file must decode to. */
    private static final int LABEL_MAGIC = 2049;
    /** Value the first 4-byte header field of the image file must decode to. */
    private static final int IMAGE_MAGIC = 2051;
    /** Number of 4-byte header fields in the label file (magic, count). */
    private static final int LABEL_HEADER_INTS = 2;
    /** Number of 4-byte header fields in the image file (magic, count, rows, columns). */
    private static final int IMAGE_HEADER_INTS = 4;

    private final String labelFile;
    private final String imageFile;
    private FileInputStream labelIO;
    private FileInputStream imageIO;
    private int labelSize;
    private int imageSize;
    private int imageX;
    private int imageY;

    /**
     * Reads from {@code in} until {@code buf} is full or the end of the stream is reached.
     * A single {@code read} call is allowed to return fewer bytes than requested, so the
     * call is repeated.
     *
     * @return the number of bytes actually stored in {@code buf}
     */
    private static int readFully(FileInputStream in, byte[] buf) throws IOException {
        int filled = 0;
        while (filled < buf.length) {
            int n = in.read(buf, filled, buf.length - filled);
            if (n < 0) {
                break;
            }
            filled += n;
        }
        return filled;
    }

    /**
     * Reads one 4-byte integer in big-endian order (the default order of {@link ByteBuffer}).
     *
     * @throws EOFException if the stream ends before four bytes could be read
     */
    private int readInt(FileInputStream fileInputStream) throws IOException {
        byte[] word = new byte[4];
        if (readFully(fileInputStream, word) < word.length) {
            throw new EOFException("Unexpected end of file while reading a 4-byte header field");
        }
        return ByteBuffer.wrap(word).getInt();
    }

    /** Consumes {@code count} 4-byte header fields from {@code stream}, discarding their values. */
    private void skipInts(FileInputStream stream, int count) throws IOException {
        for (int i = 0; i < count; i++) {
            readInt(stream);
        }
    }

    /** Closes {@code stream} if it is open; a failure to close is printed, not thrown. */
    private static void closeQuietly(FileInputStream stream) {
        if (stream == null) {
            return;
        }
        try {
            stream.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** Returns the number of images declared in the image file header. */
    @Override // org.ea.javacnn.readers.Reader
    public int size() {
        return this.imageSize;
    }

    /**
     * Opens both files and parses their headers.
     *
     * <p>Any failure (missing file, wrong magic number, truncated header, mismatching item
     * counts) is printed to standard error and NOT rethrown; fields that had not been read
     * yet keep their default values.
     *
     * @param str path of the label file
     * @param str2 path of the image file
     */
    public MnistReader(String str, String str2) {
        this.labelFile = str;
        this.imageFile = str2;
        try {
            this.labelIO = new FileInputStream(str);
            this.imageIO = new FileInputStream(str2);
            if (readInt(this.labelIO) != LABEL_MAGIC) {
                throw new IOException("Label file header missing");
            }
            if (readInt(this.imageIO) != IMAGE_MAGIC) {
                throw new IOException("Image file header missing");
            }
            this.labelSize = readInt(this.labelIO);
            this.imageSize = readInt(this.imageIO);
            if (this.labelSize != this.imageSize) {
                throw new IOException("Labels and images don't match in number.");
            }
            this.imageY = readInt(this.imageIO);
            this.imageX = readInt(this.imageIO);
            System.out.println("LSZ " + this.labelSize + " ISZ " + this.imageSize + " Y " + this.imageY + " X " + this.imageX);
        } catch (Exception e) {
            // Deliberately broad: runtime failures such as a null path are reported the
            // same way as I/O errors instead of escaping the constructor.
            e.printStackTrace();
        }
    }

    /** Returns the largest pixel value produced by {@link #readNextImage()}. */
    @Override // org.ea.javacnn.readers.Reader
    public int getMaxvalue() {
        return 255;
    }

    /** Returns the number of distinct label classes. */
    @Override // org.ea.javacnn.readers.Reader
    public int numOfClasses() {
        return 10;
    }

    /**
     * Closes both files, reopens them and skips their headers so the next label/image read
     * starts from the first item again. Failures are printed to standard error.
     */
    @Override // org.ea.javacnn.readers.Reader
    public void reset() {
        try {
            // Close each stream independently so a failure on one cannot leak the other.
            closeQuietly(this.labelIO);
            closeQuietly(this.imageIO);
            this.labelIO = new FileInputStream(this.labelFile);
            this.imageIO = new FileInputStream(this.imageFile);
            skipInts(this.labelIO, LABEL_HEADER_INTS);
            skipInts(this.imageIO, IMAGE_HEADER_INTS);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Releases both underlying file handles. The reader can be made usable again with
     * {@link #reset()}.
     */
    public void close() {
        closeQuietly(this.labelIO);
        closeQuietly(this.imageIO);
    }

    /**
     * Reads the next label byte.
     *
     * @return the label as a value in {@code 0..255}, or {@code -1} at end of file or when
     *     the read fails (the failure is printed to standard error)
     */
    @Override // org.ea.javacnn.readers.Reader
    public int readNextLabel() {
        try {
            return this.labelIO.read();
        } catch (Exception e) {
            e.printStackTrace();
            return -1;
        }
    }

    /**
     * Reads the next image as {@code sizeX * sizeY} pixel values.
     *
     * <p>Each pixel byte is converted to its unsigned value, so every entry lies in
     * {@code 0..255}. If the file ends before a full image has been read, the remaining
     * entries are {@code 0}.
     *
     * @return a newly allocated array holding one value per pixel, in file order
     * @throws IllegalStateException if the image file could not be opened
     * @throws Exception if reading from the image file fails
     */
    @Override // org.ea.javacnn.readers.Reader
    public int[] readNextImage() throws Exception {
        if (this.imageIO == null) {
            throw new IllegalStateException("Image file is not open: " + this.imageFile);
        }
        int pixelCount = this.imageX * this.imageY;
        byte[] raw = new byte[pixelCount]; // freshly allocated arrays are already zero-filled
        readFully(this.imageIO, raw);
        int[] pixels = new int[pixelCount];
        for (int i = 0; i < pixelCount; i++) {
            // Java bytes are signed; mask to recover the unsigned pixel intensity.
            pixels[i] = raw[i] & 0xFF;
        }
        return pixels;
    }

    /**
     * Manual smoke test: prints the first 201 labels, skips 200 images and dumps the
     * following image as rows of pixel values.
     */
    public static void main(String[] strArr) {
        MnistReader mnistReader = new MnistReader("mnist/t10k-labels.idx1-ubyte", "mnist/t10k-images.idx3-ubyte");
        try {
            for (int i = 0; i <= 200; i++) {
                System.out.print(mnistReader.readNextLabel());
            }
            for (int i = 0; i < 200; i++) {
                try {
                    mnistReader.readNextImage();
                } catch (Exception e) {
                    e.printStackTrace();
                    System.out.println("Crash at " + i);
                }
            }
            try {
                int[] pixels = mnistReader.readNextImage();
                int rowWidth = mnistReader.getSizeX();
                for (int i = 0; i < pixels.length; i++) {
                    if (rowWidth > 0 && i % rowWidth == 0) {
                        System.out.println();
                    }
                    System.out.print(pixels[i] + " ");
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } finally {
            mnistReader.close();
        }
    }

    /** Returns the image width read from the image file header. */
    @Override // org.ea.javacnn.readers.Reader
    public int getSizeX() {
        return this.imageX;
    }

    /** Returns the image height read from the image file header. */
    @Override // org.ea.javacnn.readers.Reader
    public int getSizeY() {
        return this.imageY;
    }
}
