/*
 * Decompiled with CFR 0.152.
 */
package org.encog.util.data;

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import org.encog.EncogError;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.data.basic.BasicMLDataPair;
import org.encog.ml.data.basic.BasicMLDataSet;

/**
 * Reads an MNIST-format label file and image file into an {@link MLDataSet}.
 *
 * <p>Header integers are read with {@link DataInputStream#readInt()} (big-endian). The label
 * file must start with magic number 2049 followed by the label count, then one byte per label.
 * The image file must start with magic number 2051 followed by the image count, the row count
 * and the column count, then {@code rows * cols} bytes per image.
 *
 * <p>Each pair in the resulting data set has:
 * <ul>
 * <li>an input vector of {@code rows * cols} values, each pixel byte read as unsigned and
 * divided by 255.0;</li>
 * <li>an ideal vector of length 10 whose element at the label's index is set to 1.0.</li>
 * </ul>
 */
public class MNISTReader {
    /** Magic number expected at the start of a label file. */
    private static final int LABEL_MAGIC = 2049;
    /** Magic number expected at the start of an image file. */
    private static final int IMAGE_MAGIC = 2051;
    /** Length of the ideal (one-hot) vector; labels must lie in 0..NUM_CLASSES-1. */
    private static final int NUM_CLASSES = 10;
    /** Largest unsigned pixel byte value, used to scale pixels into [0, 1]. */
    private static final double MAX_PIXEL_VALUE = 255.0;

    private final int numLabels;
    private final int numImages;
    private final int numRows;
    private final int numCols;
    private final MLDataSet data;

    /**
     * Loads an MNIST label file and image file fully into memory.
     *
     * @param labelFilename path of the label file (magic number 2049)
     * @param imageFilename path of the image file (magic number 2051)
     * @throws EncogError if either file cannot be opened or read or is truncated, has the
     *         wrong magic number, the two files disagree on the number of entries, a header
     *         value is negative or too large to load, or a label is outside 0-9
     */
    public MNISTReader(String labelFilename, String imageFilename) {
        try {
            DataInputStream labels = new DataInputStream(new FileInputStream(labelFilename));
            try {
                DataInputStream images = new DataInputStream(new FileInputStream(imageFilename));
                try {
                    checkMagicNumber(labels.readInt(), LABEL_MAGIC, "Label file");
                    checkMagicNumber(images.readInt(), IMAGE_MAGIC, "Image file");
                    this.numLabels = labels.readInt();
                    this.numImages = images.readInt();
                    this.numRows = images.readInt();
                    this.numCols = images.readInt();
                    if (this.numLabels != this.numImages) {
                        StringBuilder str = new StringBuilder();
                        str.append("Image file and label file do not contain the same number of entries.\n");
                        str.append("  Label file contains: " + this.numLabels + "\n");
                        str.append("  Image file contains: " + this.numImages + "\n");
                        throw new EncogError(str.toString());
                    }
                    // Reject corrupt headers before they turn into a
                    // NegativeArraySizeException or an int overflow in the buffer size.
                    if (this.numLabels < 0 || this.numRows < 0 || this.numCols < 0) {
                        throw new EncogError("MNIST header contains a negative value: entries="
                                + this.numLabels + ", rows=" + this.numRows + ", cols=" + this.numCols);
                    }
                    long totalPixels = (long) this.numLabels * this.numRows * this.numCols;
                    long vectorSize = (long) this.numRows * this.numCols;
                    if (totalPixels > Integer.MAX_VALUE || vectorSize > Integer.MAX_VALUE) {
                        throw new EncogError("Image file is too large to load: " + this.numLabels
                                + " images of " + this.numRows + "x" + this.numCols + " pixels");
                    }
                    int imageVectorSize = (int) vectorSize;

                    // readFully (unlike read) either fills the whole buffer or throws
                    // EOFException, so a truncated file is reported instead of being
                    // silently zero-filled.
                    byte[] labelsData = new byte[this.numLabels];
                    labels.readFully(labelsData);
                    byte[] imagesData = new byte[(int) totalPixels];
                    images.readFully(imagesData);

                    this.data = buildDataSet(labelsData, imagesData, imageVectorSize);
                } finally {
                    images.close();
                }
            } finally {
                labels.close();
            }
        } catch (IOException ex) {
            throw new EncogError(ex);
        }
    }

    /**
     * Throws if a file's leading magic number is not the expected one.
     *
     * @param actual the magic number read from the file
     * @param expected the magic number required for this kind of file
     * @param fileKind prefix used in the error message, e.g. {@code "Label file"}
     * @throws EncogError if {@code actual != expected}
     */
    private static void checkMagicNumber(int actual, int expected, String fileKind) {
        if (actual != expected) {
            throw new EncogError(fileKind + " has wrong magic number: " + actual
                    + " (should be " + expected + ")");
        }
    }

    /**
     * Pairs each label with its image and collects the pairs into a new data set.
     *
     * @param labelsData one byte per entry, interpreted as an unsigned class index
     * @param imagesData {@code labelsData.length * imageVectorSize} pixel bytes, with the
     *        images stored one after another
     * @param imageVectorSize number of pixels per image
     * @return a data set of (scaled pixels, one-hot label) pairs in file order
     * @throws EncogError if a label is not in {@code 0..NUM_CLASSES-1}
     */
    private static MLDataSet buildDataSet(byte[] labelsData, byte[] imagesData, int imageVectorSize) {
        MLDataSet result = new BasicMLDataSet();
        int imageIndex = 0;
        for (int i = 0; i < labelsData.length; i++) {
            // Decode as unsigned so a corrupt byte >= 128 is reported as out of range
            // rather than becoming a negative array index.
            int label = labelsData[i] & 0xFF;
            if (label >= NUM_CLASSES) {
                throw new EncogError("Label at index " + i + " is out of range: " + label
                        + " (should be 0-" + (NUM_CLASSES - 1) + ")");
            }
            BasicMLData inputData = new BasicMLData(imageVectorSize);
            for (int j = 0; j < imageVectorSize; j++) {
                inputData.setData(j, (imagesData[imageIndex++] & 0xFF) / MAX_PIXEL_VALUE);
            }
            BasicMLData idealData = new BasicMLData(NUM_CLASSES);
            idealData.setData(label, 1.0);
            result.add(new BasicMLDataPair(inputData, idealData));
        }
        return result;
    }

    /** @return the number of labels declared in the label file header */
    public int getNumLabels() {
        return this.numLabels;
    }

    /** @return the number of images declared in the image file header */
    public int getNumImages() {
        return this.numImages;
    }

    /** @return the number of pixel rows per image, from the image file header */
    public int getNumRows() {
        return this.numRows;
    }

    /** @return the number of pixel columns per image, from the image file header */
    public int getNumCols() {
        return this.numCols;
    }

    /** @return the loaded data set; this is the internal instance, not a copy */
    public MLDataSet getData() {
        return this.data;
    }
}

