/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.imgrec;

import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import javax.imageio.ImageIO;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.exceptions.VectorSizeMismatchException;
import org.neuroph.imgrec.ColorMode;
import org.neuroph.imgrec.FractionHSLData;
import org.neuroph.imgrec.FractionRgbData;
import org.neuroph.imgrec.ImageFilesIterator;
import org.neuroph.imgrec.ImageRecognitionPlugin;
import org.neuroph.imgrec.ImageSampler;
import org.neuroph.imgrec.ImageUtilities;
import org.neuroph.imgrec.image.Dimension;
import org.neuroph.imgrec.image.Image;
import org.neuroph.imgrec.image.ImageFactory;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/**
 * Static helpers for building image-recognition neural networks and the
 * training sets they are trained on.
 *
 * <p>Image data is keyed by the image file name up to its first {@code '.'};
 * a training row's desired output is derived by matching that key against the
 * list of image labels (see {@link #createResponse(String, List)}).
 */
public class ImageRecognitionHelper {
    /**
     * Creates a multi layer perceptron configured for image recognition.
     *
     * <p>The input layer has one neuron per sampled pixel, or three per pixel
     * when {@code colorMode} is {@code COLOR_RGB} or {@code COLOR_HSL}. The
     * output layer has one neuron per image label. The network gets an
     * {@link ImageRecognitionPlugin} and a {@link MomentumBackpropagation}
     * learning rule.
     *
     * @param label                label assigned to the created network
     * @param samplingResolution   resolution images are sampled at
     * @param colorMode            color mode that determines the input layer size
     * @param imageLabels          labels assigned, in order, to the output neurons
     * @param layersNeuronsCount   neuron counts of the hidden layers; this list
     *                             is not modified
     * @param transferFunctionType transfer function passed to the perceptron
     * @return the configured neural network
     */
    public static NeuralNetwork createNewNeuralNetwork(String label, Dimension samplingResolution, ColorMode colorMode, List<String> imageLabels, List<Integer> layersNeuronsCount, TransferFunctionType transferFunctionType) {
        int pixelCount = samplingResolution.getWidth() * samplingResolution.getHeight();
        boolean threeChannels = colorMode == ColorMode.COLOR_RGB || colorMode == ColorMode.COLOR_HSL;
        int numberOfInputNeurons = threeChannels ? 3 * pixelCount : pixelCount;
        int numberOfOutputNeurons = imageLabels.size();

        // Build the full layer layout in a private copy so the caller's list is
        // left untouched (reusing it for a second network must not accumulate
        // extra layers, and unmodifiable lists must be accepted).
        List<Integer> layerSizes = new ArrayList<Integer>(layersNeuronsCount.size() + 2);
        layerSizes.add(numberOfInputNeurons);
        layerSizes.addAll(layersNeuronsCount);
        layerSizes.add(numberOfOutputNeurons);
        System.out.println("Neuron layer size counts vector = " + layerSizes);

        MultiLayerPerceptron neuralNetwork = new MultiLayerPerceptron(layerSizes, transferFunctionType);
        neuralNetwork.setLabel(label);
        ImageRecognitionPlugin imageRecognitionPlugin = new ImageRecognitionPlugin(samplingResolution, colorMode);
        neuralNetwork.addPlugin(imageRecognitionPlugin);
        ImageRecognitionHelper.assignLabelsToOutputNeurons(neuralNetwork, imageLabels);
        neuralNetwork.setLearningRule(new MomentumBackpropagation());
        return neuralNetwork;
    }

    /**
     * Labels the i-th output neuron of the network with the i-th image label.
     */
    private static void assignLabelsToOutputNeurons(NeuralNetwork neuralNetwork, List<String> imageLabels) {
        List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
        for (int i = 0; i < outputNeurons.size(); ++i) {
            Neuron neuron = outputNeurons.get(i);
            String label = imageLabels.get(i);
            neuron.setLabel(label);
        }
    }

    /**
     * Creates an RGB training set; delegates to
     * {@link #createRGBTrainingSet(List, Map)}.
     *
     * @param imageLabels labels that define the output columns
     * @param rgbDataMap  RGB data keyed by image name
     * @return the training set
     */
    public static DataSet createTrainingSet(List<String> imageLabels, Map<String, FractionRgbData> rgbDataMap) {
        return ImageRecognitionHelper.createRGBTrainingSet(imageLabels, rgbDataMap);
    }

    /**
     * Creates a training set whose inputs are the flattened RGB values of each
     * image and whose outputs are the label responses for the image name.
     *
     * @param imageLabels labels that define the output columns
     * @param rgbDataMap  RGB data keyed by image name; must not be empty
     * @return the training set, with output columns named after the labels
     * @throws IllegalArgumentException if {@code rgbDataMap} is empty
     */
    public static DataSet createRGBTrainingSet(List<String> imageLabels, Map<String, FractionRgbData> rgbDataMap) {
        ImageRecognitionHelper.requireImageData(rgbDataMap);
        // The first entry determines the input vector size of the whole set.
        int inputCount = rgbDataMap.values().iterator().next().getFlattenedRgbValues().length;
        int outputCount = imageLabels.size();
        DataSet trainingSet = new DataSet(inputCount, outputCount);
        for (Map.Entry<String, FractionRgbData> entry : rgbDataMap.entrySet()) {
            double[] input = entry.getValue().getFlattenedRgbValues();
            double[] response = ImageRecognitionHelper.createResponse(entry.getKey(), imageLabels);
            trainingSet.addRow(new DataSetRow(input, response));
        }
        ImageRecognitionHelper.nameOutputColumns(trainingSet, imageLabels);
        return trainingSet;
    }

    /**
     * Creates a training set whose inputs are the flattened HSL values of each
     * image and whose outputs are the label responses for the image name.
     *
     * @param imageLabels labels that define the output columns
     * @param hslDataMap  HSL data keyed by image name; must not be empty
     * @return the training set, with output columns named after the labels
     * @throws IllegalArgumentException if {@code hslDataMap} is empty
     */
    public static DataSet createHSLTrainingSet(List<String> imageLabels, Map<String, FractionHSLData> hslDataMap) {
        ImageRecognitionHelper.requireImageData(hslDataMap);
        // The first entry determines the input vector size of the whole set.
        int inputCount = hslDataMap.values().iterator().next().getFlattenedHSLValues().length;
        int outputCount = imageLabels.size();
        DataSet trainingSet = new DataSet(inputCount, outputCount);
        for (Map.Entry<String, FractionHSLData> entry : hslDataMap.entrySet()) {
            double[] input = entry.getValue().getFlattenedHSLValues();
            double[] response = ImageRecognitionHelper.createResponse(entry.getKey(), imageLabels);
            trainingSet.addRow(new DataSetRow(input, response));
        }
        ImageRecognitionHelper.nameOutputColumns(trainingSet, imageLabels);
        return trainingSet;
    }

    /**
     * Creates a training set whose inputs are the flattened RGB values of each
     * image converted with
     * {@code FractionRgbData.convertRgbInputToBinaryBlackAndWhite}; the input
     * size is one third of the flattened RGB length.
     *
     * @param imageLabels labels that define the output columns
     * @param rgbDataMap  RGB data keyed by image name; must not be empty
     * @return the training set, with output columns named after the labels
     * @throws IllegalArgumentException    if {@code rgbDataMap} is empty
     * @throws VectorSizeMismatchException declared by the original signature
     */
    public static DataSet createBlackAndWhiteTrainingSet(List<String> imageLabels, Map<String, FractionRgbData> rgbDataMap) throws VectorSizeMismatchException {
        ImageRecognitionHelper.requireImageData(rgbDataMap);
        int inputCount = rgbDataMap.values().iterator().next().getFlattenedRgbValues().length / 3;
        int outputCount = imageLabels.size();
        DataSet trainingSet = new DataSet(inputCount, outputCount);
        for (Map.Entry<String, FractionRgbData> entry : rgbDataMap.entrySet()) {
            double[] inputRGB = entry.getValue().getFlattenedRgbValues();
            double[] inputBW = FractionRgbData.convertRgbInputToBinaryBlackAndWhite(inputRGB);
            double[] response = ImageRecognitionHelper.createResponse(entry.getKey(), imageLabels);
            trainingSet.addRow(new DataSetRow(inputBW, response));
        }
        ImageRecognitionHelper.nameOutputColumns(trainingSet, imageLabels);
        return trainingSet;
    }

    /**
     * Rejects an empty image data map with a descriptive error; the training
     * set builders need at least one image to determine the input size.
     */
    private static void requireImageData(Map<String, ?> imageDataMap) {
        if (imageDataMap.isEmpty()) {
            throw new IllegalArgumentException("Cannot create a training set from an empty image data map.");
        }
    }

    /**
     * Names each output column of the training set after the image label at
     * the same position; output columns start right after the input columns.
     */
    private static void nameOutputColumns(DataSet trainingSet, List<String> imageLabels) {
        int inputSize = trainingSet.getInputSize();
        for (int c = 0; c < trainingSet.getOutputSize(); ++c) {
            trainingSet.setColumnName(inputSize + c, imageLabels.get(c));
        }
    }

    /**
     * Loads every image the {@link ImageFilesIterator} yields for the given
     * directory, down-samples it to {@code samplingResolution} and stores its
     * RGB data keyed by the file name up to its first {@code '.'}.
     *
     * @param imgDir             directory containing the images
     * @param samplingResolution resolution the images are down-sampled to
     * @return RGB data keyed by image name
     * @throws IOException              if loading an image fails
     * @throws IllegalArgumentException if {@code imgDir} is not a directory
     */
    public static Map<String, FractionRgbData> getFractionRgbDataForDirectory(File imgDir, Dimension samplingResolution) throws IOException {
        if (!imgDir.isDirectory()) {
            throw new IllegalArgumentException("The given file must be a directory.  Argument is: " + imgDir);
        }
        Map<String, FractionRgbData> rgbDataMap = new HashMap<String, FractionRgbData>();
        ImageFilesIterator imagesIterator = new ImageFilesIterator(imgDir);
        while (imagesIterator.hasNext()) {
            File imgFile = imagesIterator.next();
            Image img = ImageFactory.getImage(imgFile);
            img = ImageSampler.downSampleImage(samplingResolution, img);
            String filenameOfCurrentImage = imagesIterator.getFilenameOfCurrentImage();
            // Key is the file name without anything from the first '.' on.
            StringTokenizer st = new StringTokenizer(filenameOfCurrentImage, ".");
            rgbDataMap.put(st.nextToken(), new FractionRgbData(img));
        }
        return rgbDataMap;
    }

    /**
     * Loads every image the {@link ImageFilesIterator} yields for the given
     * directory, resizes it to {@code samplingResolution} and stores its HSL
     * data keyed by the file name up to its first {@code '.'}.
     *
     * <p>Loading is best-effort: a file that cannot be decoded is skipped, and
     * any exception raised while processing stops the iteration, is printed,
     * and the data collected so far is returned.
     *
     * @param imgDir             directory containing the images
     * @param samplingResolution resolution the images are resized to
     * @return HSL data keyed by image name (possibly partial, see above)
     * @throws IOException              declared by the original signature
     * @throws IllegalArgumentException if {@code imgDir} is not a directory
     */
    public static Map<String, FractionHSLData> getFractionHSLDataForDirectory(File imgDir, Dimension samplingResolution) throws IOException {
        if (!imgDir.isDirectory()) {
            throw new IllegalArgumentException("The given file must be a directory.  Argument is: " + imgDir);
        }
        Map<String, FractionHSLData> map = new HashMap<String, FractionHSLData>();
        ImageFilesIterator imagesIterator = new ImageFilesIterator(imgDir);
        try {
            while (imagesIterator.hasNext()) {
                File imgFile = imagesIterator.next();
                BufferedImage img = ImageIO.read(imgFile);
                if (img == null) {
                    // ImageIO.read returns null when no reader can decode the
                    // file; skip it instead of failing on a null image and
                    // silently dropping every remaining file.
                    System.err.println("Skipping unreadable image file: '" + imgFile + "'");
                    continue;
                }
                BufferedImage image = ImageUtilities.resizeImage(img, samplingResolution.getWidth(), samplingResolution.getHeight());
                String filenameOfCurrentImage = imgFile.getName();
                // Key is the file name without anything from the first '.' on.
                StringTokenizer st = new StringTokenizer(filenameOfCurrentImage, ".");
                map.put(st.nextToken(), new FractionHSLData(image));
            }
        }
        catch (Exception e) {
            // Kept from the original: failures are reported but do not
            // propagate, so callers receive whatever was loaded up to here.
            e.printStackTrace();
        }
        return map;
    }

    /**
     * Builds the desired output vector for an image: position {@code i} is
     * {@code 1.0} when {@code inputLabel} starts with the i-th image label and
     * {@code 0.0} otherwise.
     *
     * <p>Because matching is by prefix, more than one position can be
     * {@code 1.0} when one label is a prefix of another.
     */
    private static double[] createResponse(String inputLabel, List<String> imageLabels) {
        double[] response = new double[imageLabels.size()];
        int i = 0;
        for (String imageLabel : imageLabels) {
            response[i] = inputLabel.startsWith(imageLabel) ? 1.0 : 0.0;
            ++i;
        }
        return response;
    }

    /**
     * Derives the sorted list of distinct labels from image names, where the
     * label of an image is its name up to the first {@code '.'} or {@code '_'}.
     * Names that consist only of those delimiters contribute no label.
     */
    private static List<String> createLabels(Map<String, ?> map) {
        List<String> imageLabels = new ArrayList<String>();
        for (String imgName : map.keySet()) {
            StringTokenizer st = new StringTokenizer(imgName, "._");
            if (!st.hasMoreTokens()) {
                continue;
            }
            String imageLabel = st.nextToken();
            if (!imageLabels.contains(imageLabel)) {
                imageLabels.add(imageLabel);
            }
        }
        Collections.sort(imageLabels);
        return imageLabels;
    }

    /**
     * Creates a labeled training set from the images in {@code imageDir},
     * optionally extended with the images in {@code junkDir}.
     *
     * <p>HSL data is loaded for {@code COLOR_HSL}; RGB data is loaded for every
     * other color mode, and any mode other than {@code COLOR_RGB} and
     * {@code COLOR_HSL} produces a black-and-white training set. When
     * {@code imageLabels} is {@code null} the labels are derived from the names
     * of the images found in {@code imageDir}. An {@link IOException} while
     * loading either directory is reported on {@code System.err}.
     *
     * @param imageDir           directory with the labeled images
     * @param imageLabels        labels for the outputs, or {@code null} to
     *                           derive them from the image names
     * @param junkDir            directory with additional images, or
     *                           {@code null}/empty for none
     * @param colorMode          color mode of the training set
     * @param samplingResolution resolution the images are sampled at
     * @param trainingSetName    label assigned to the created data set
     * @param numOfPictures      not used by this method
     * @return the created training set
     * @throws IllegalArgumentException if no labels are available or no image
     *                                  data could be loaded
     */
    public static DataSet createImageDataSetFromFile(String imageDir, List<String> imageLabels, String junkDir, ColorMode colorMode, Dimension samplingResolution, String trainingSetName, int numOfPictures) {
        HashMap<String, FractionRgbData> rgbDataMap = new HashMap<String, FractionRgbData>();
        HashMap<String, FractionHSLData> hslDataMap = new HashMap<String, FractionHSLData>();
        try {
            File labeledImagesDir = new File(imageDir);
            if (colorMode == ColorMode.COLOR_HSL) {
                hslDataMap.putAll(ImageRecognitionHelper.getFractionHSLDataForDirectory(labeledImagesDir, samplingResolution));
                if (imageLabels == null) {
                    imageLabels = ImageRecognitionHelper.createLabels(hslDataMap);
                }
            } else {
                rgbDataMap.putAll(ImageRecognitionHelper.getFractionRgbDataForDirectory(labeledImagesDir, samplingResolution));
                if (imageLabels == null) {
                    imageLabels = ImageRecognitionHelper.createLabels(rgbDataMap);
                }
            }
        }
        catch (IOException ioe) {
            System.err.println("Unable to load images from labeled images dir: '" + imageDir + "'");
            System.err.println(ioe.toString());
        }
        if (junkDir != null && !junkDir.equals("")) {
            try {
                File junkImagesDir = new File(junkDir);
                if (colorMode == ColorMode.COLOR_HSL) {
                    hslDataMap.putAll(ImageRecognitionHelper.getFractionHSLDataForDirectory(junkImagesDir, samplingResolution));
                } else {
                    rgbDataMap.putAll(ImageRecognitionHelper.getFractionRgbDataForDirectory(junkImagesDir, samplingResolution));
                }
            }
            catch (IOException ioe) {
                System.err.println("Unable to load images from junk images dir: '" + junkDir + "'");
                System.err.println(ioe.toString());
            }
        }
        if (imageLabels == null) {
            // Only reachable when loading the labeled images failed before the
            // labels could be derived; fail clearly instead of with an NPE.
            throw new IllegalArgumentException("No image labels were given and none could be derived from labeled images dir: '" + imageDir + "'");
        }

        DataSet dataSet;
        if (colorMode == ColorMode.COLOR_RGB) {
            dataSet = ImageRecognitionHelper.createRGBTrainingSet(imageLabels, rgbDataMap);
        } else if (colorMode == ColorMode.COLOR_HSL) {
            dataSet = ImageRecognitionHelper.createHSLTrainingSet(imageLabels, hslDataMap);
        } else {
            dataSet = ImageRecognitionHelper.createBlackAndWhiteTrainingSet(imageLabels, rgbDataMap);
        }
        dataSet.setLabel(trainingSetName);
        // toArray() without an argument yields Object[], which cannot be cast
        // to String[]; request a String[] explicitly.
        // NOTE(review): this replaces the column names with the labels only,
        // after the builders above named the output columns at their offset
        // behind the inputs - confirm this is what DataSet expects.
        dataSet.setColumnNames(imageLabels.toArray(new String[imageLabels.size()]));
        // The data set is intentionally not saved here: the previous version
        // wrote it to a hard-coded, machine-specific absolute path. Callers
        // that need persistence can save the returned data set themselves.
        return dataSet;
    }
}

