package org.neuroph.imgrec;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import javax.imageio.ImageIO;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.Neuron;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.exceptions.VectorSizeMismatchException;
import org.neuroph.imgrec.image.Dimension;
import org.neuroph.imgrec.image.ImageFactory;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.TransferFunctionType;

/* loaded from: input_file:org/neuroph/imgrec/ImageRecognitionHelper.class */
/**
 * Static helpers for building image-recognition neural networks and the training sets that feed
 * them from directories of image files whose names begin with the image's label.
 */
public class ImageRecognitionHelper {

    /**
     * Creates a multi-layer perceptron configured for image recognition.
     *
     * <p>The input layer has three neurons per pixel for {@code COLOR_RGB} and {@code COLOR_HSL}
     * and one neuron per pixel for any other color mode. The output layer has one neuron per
     * label. The network gets an {@code ImageRecognitionPlugin}, one label per output neuron, and
     * a {@code MomentumBackpropagation} learning rule.
     *
     * @param str label assigned to the network
     * @param dimension image sampling resolution; determines the input layer size
     * @param colorMode how pixels are encoded as network inputs
     * @param list image labels, one per output neuron, in output order
     * @param list2 hidden layer neuron counts; the caller's list is not modified
     * @param transferFunctionType transfer function used by the perceptron
     * @return the configured network
     */
    public static NeuralNetwork createNewNeuralNetwork(String str, Dimension dimension, ColorMode colorMode, List<String> list, List<Integer> list2, TransferFunctionType transferFunctionType) {
        int pixelCount = dimension.getWidth() * dimension.getHeight();
        boolean threeValuesPerPixel = colorMode == ColorMode.COLOR_RGB || colorMode == ColorMode.COLOR_HSL;
        int inputCount = threeValuesPerPixel ? 3 * pixelCount : pixelCount;

        // Work on a copy. Adding the input/output sizes to the caller's list would corrupt it for
        // reuse (sizes accumulate on a second call) and would fail on unmodifiable lists.
        List<Integer> layerSizes = new ArrayList<Integer>(list2);
        layerSizes.add(0, Integer.valueOf(inputCount));
        layerSizes.add(Integer.valueOf(list.size()));
        System.out.println("Neuron layer size counts vector = " + layerSizes);

        MultiLayerPerceptron network = new MultiLayerPerceptron(layerSizes, transferFunctionType);
        network.setLabel(str);
        network.addPlugin(new ImageRecognitionPlugin(dimension, colorMode));
        assignLabelsToOutputNeurons(network, list);
        network.setLearningRule(new MomentumBackpropagation());
        return network;
    }

    /**
     * Labels each output neuron with the label at the same index.
     * The label list must have at least as many entries as there are output neurons.
     */
    private static void assignLabelsToOutputNeurons(NeuralNetwork neuralNetwork, List<String> list) {
        List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
        for (int i = 0; i < outputNeurons.size(); i++) {
            outputNeurons.get(i).setLabel(list.get(i));
        }
    }

    /**
     * Builds an RGB training set; equivalent to {@link #createRGBTrainingSet(List, Map)}.
     *
     * @param list labels, one per output column
     * @param map RGB image data keyed by image name
     * @return the training set
     */
    public static DataSet createTrainingSet(List<String> list, Map<String, FractionRgbData> map) {
        return createRGBTrainingSet(list, map);
    }

    /**
     * Builds a training set whose inputs are the flattened RGB values of each image.
     *
     * <p>The expected output for each row has 1.0 for every label that the image name starts
     * with and 0.0 otherwise. Output columns are named after the labels.
     *
     * @param list labels, one per output column
     * @param map RGB image data keyed by image name; must not be empty
     * @return the training set
     * @throws IllegalArgumentException if {@code map} is empty
     */
    public static DataSet createRGBTrainingSet(List<String> list, Map<String, FractionRgbData> map) {
        // Input size is taken from one image; all images are expected to share it.
        DataSet dataSet = new DataSet(anyValue(map).getFlattenedRgbValues().length, list.size());
        for (Map.Entry<String, FractionRgbData> entry : map.entrySet()) {
            dataSet.addRow(new DataSetRow(entry.getValue().getFlattenedRgbValues(), createResponse(entry.getKey(), list)));
        }
        nameOutputColumns(dataSet, list);
        return dataSet;
    }

    /**
     * Builds a training set whose inputs are the flattened HSL values of each image.
     *
     * @param list labels, one per output column
     * @param map HSL image data keyed by image name; must not be empty
     * @return the training set
     * @throws IllegalArgumentException if {@code map} is empty
     */
    public static DataSet createHSLTrainingSet(List<String> list, Map<String, FractionHSLData> map) {
        DataSet dataSet = new DataSet(anyValue(map).getFlattenedHSLValues().length, list.size());
        for (Map.Entry<String, FractionHSLData> entry : map.entrySet()) {
            dataSet.addRow(new DataSetRow(entry.getValue().getFlattenedHSLValues(), createResponse(entry.getKey(), list)));
        }
        nameOutputColumns(dataSet, list);
        return dataSet;
    }

    /**
     * Builds a training set whose inputs are black-and-white values converted from each image's
     * RGB data by {@code FractionRgbData.convertRgbInputToBinaryBlackAndWhite}.
     *
     * @param list labels, one per output column
     * @param map RGB image data keyed by image name; must not be empty
     * @return the training set
     * @throws VectorSizeMismatchException declared by the original API
     * @throws IllegalArgumentException if {@code map} is empty
     */
    public static DataSet createBlackAndWhiteTrainingSet(List<String> list, Map<String, FractionRgbData> map) throws VectorSizeMismatchException {
        // NOTE(review): dividing by 3 assumes the flattened RGB array holds three values per pixel
        // and that the black-and-white conversion yields one value per pixel. TODO confirm.
        DataSet dataSet = new DataSet(anyValue(map).getFlattenedRgbValues().length / 3, list.size());
        for (Map.Entry<String, FractionRgbData> entry : map.entrySet()) {
            dataSet.addRow(new DataSetRow(FractionRgbData.convertRgbInputToBinaryBlackAndWhite(entry.getValue().getFlattenedRgbValues()), createResponse(entry.getKey(), list)));
        }
        nameOutputColumns(dataSet, list);
        return dataSet;
    }

    /** Returns an arbitrary value from {@code map}, used to size data set inputs; rejects empty maps. */
    private static <V> V anyValue(Map<String, V> map) {
        if (map.isEmpty()) {
            throw new IllegalArgumentException("Cannot create a training set: no image data was provided.");
        }
        return map.values().iterator().next();
    }

    /** Names the output columns (which follow the input columns) after the labels. */
    private static void nameOutputColumns(DataSet dataSet, List<String> labels) {
        int inputSize = dataSet.getInputSize();
        for (int i = 0; i < dataSet.getOutputSize(); i++) {
            dataSet.setColumnName(inputSize + i, labels.get(i));
        }
    }

    /**
     * Loads and down-samples every image in a directory as RGB data.
     *
     * @param file directory containing the images
     * @param dimension sampling resolution applied to each image
     * @return image data keyed by the file name up to its first '.'
     * @throws IllegalArgumentException if {@code file} is not a directory
     * @throws IOException if the directory or an image cannot be read
     */
    public static Map<String, FractionRgbData> getFractionRgbDataForDirectory(File file, Dimension dimension) throws IOException {
        if (!file.isDirectory()) {
            throw new IllegalArgumentException("The given file must be a directory.  Argument is: " + file);
        }
        Map<String, FractionRgbData> result = new HashMap<String, FractionRgbData>();
        ImageFilesIterator imageFiles = new ImageFilesIterator(file);
        while (imageFiles.hasNext()) {
            // Advance the iterator first, then derive the key from the file actually being loaded.
            // Reading the file name before calling next() keyed each image by the previous file.
            File imageFile = imageFiles.next();
            String key = new StringTokenizer(imageFile.getName(), ".").nextToken();
            result.put(key, new FractionRgbData(ImageSampler.downSampleImage(dimension, ImageFactory.getImage(imageFile))));
        }
        return result;
    }

    /**
     * Loads, resizes and converts every image in a directory to HSL data.
     *
     * <p>Best effort: an image that fails to load or convert is reported on stderr and skipped.
     *
     * @param file directory containing the images
     * @param dimension size each image is resized to
     * @return image data keyed by the file name up to its first '.'
     * @throws IllegalArgumentException if {@code file} is not a directory
     * @throws IOException if the directory cannot be listed
     */
    public static Map<String, FractionHSLData> getFractionHSLDataForDirectory(File file, Dimension dimension) throws IOException {
        if (!file.isDirectory()) {
            throw new IllegalArgumentException("The given file must be a directory.  Argument is: " + file);
        }
        Map<String, FractionHSLData> result = new HashMap<String, FractionHSLData>();
        ImageFilesIterator imageFiles = new ImageFilesIterator(file);
        while (imageFiles.hasNext()) {
            File imageFile = imageFiles.next();
            try {
                String key = new StringTokenizer(imageFile.getName(), ".").nextToken();
                result.put(key, new FractionHSLData(ImageUtilities.resizeImage(ImageIO.read(imageFile), dimension.getWidth(), dimension.getHeight())));
            } catch (Exception e) {
                // Deliberately skip this image and keep loading the rest of the directory.
                System.err.println("Skipping image that could not be loaded: '" + imageFile + "'");
                System.err.println(e.toString());
            }
        }
        return result;
    }

    /**
     * Builds the expected output vector for an image.
     *
     * @param str image name
     * @param list labels, one per output
     * @return 1.0 at every label that {@code str} starts with, 0.0 elsewhere. Several entries can
     *     be 1.0 when one label is a prefix of another.
     */
    private static double[] createResponse(String str, List<String> list) {
        double[] response = new double[list.size()];
        int index = 0;
        for (String label : list) {
            response[index++] = str.startsWith(label) ? 1.0d : 0.0d;
        }
        return response;
    }

    /**
     * Derives a sorted, de-duplicated label list from image names, taking the first token of each
     * key when it is split on '.' and '_'.
     */
    private static List<String> createLabels(Map<String, ?> map) {
        List<String> labels = new ArrayList<String>();
        for (String key : map.keySet()) {
            String label = new StringTokenizer(key, "._").nextToken();
            if (!labels.contains(label)) {
                labels.add(label);
            }
        }
        Collections.sort(labels);
        return labels;
    }

    /**
     * Loads the labeled images, and optionally a directory of junk images, and builds a training
     * set for the requested color mode.
     *
     * @param str directory of labeled images
     * @param list output labels; when {@code null} they are derived from the labeled image names
     * @param str2 optional directory of junk images; ignored when {@code null} or empty
     * @param colorMode {@code COLOR_RGB}, {@code COLOR_HSL}, or any other mode for a
     *     black-and-white set
     * @param dimension sampling resolution applied to every image
     * @param str3 label assigned to the returned data set
     * @param i currently unused
     * @return the training set
     * @throws IllegalStateException if no labels were given and none could be derived
     */
    public static DataSet createImageDataSetFromFile(String str, List<String> list, String str2, ColorMode colorMode, Dimension dimension, String str3, int i) {
        Map<String, FractionRgbData> rgbData = new HashMap<String, FractionRgbData>();
        Map<String, FractionHSLData> hslData = new HashMap<String, FractionHSLData>();
        boolean useHsl = colorMode == ColorMode.COLOR_HSL;

        try {
            File imageDir = new File(str);
            if (useHsl) {
                hslData.putAll(getFractionHSLDataForDirectory(imageDir, dimension));
                if (list == null) {
                    list = createLabels(hslData);
                }
            } else {
                rgbData.putAll(getFractionRgbDataForDirectory(imageDir, dimension));
                if (list == null) {
                    list = createLabels(rgbData);
                }
            }
        } catch (IOException e) {
            System.err.println("Unable to load images from labeled images dir: '" + str + "'");
            System.err.println(e.toString());
        }

        if (str2 != null && !str2.equals("")) {
            try {
                File junkDir = new File(str2);
                if (useHsl) {
                    hslData.putAll(getFractionHSLDataForDirectory(junkDir, dimension));
                } else {
                    rgbData.putAll(getFractionRgbDataForDirectory(junkDir, dimension));
                }
            } catch (IOException e2) {
                System.err.println("Unable to load images from junk images dir: '" + str2 + "'");
                System.err.println(e2.toString());
            }
        }

        if (list == null) {
            // Only reachable when no labels were passed and the labeled directory failed to load.
            throw new IllegalStateException("No labels were given and none could be derived from images dir: '" + str + "'");
        }

        DataSet dataSet;
        if (colorMode == ColorMode.COLOR_RGB) {
            dataSet = createRGBTrainingSet(list, rgbData);
        } else if (useHsl) {
            dataSet = createHSLTrainingSet(list, hslData);
        } else {
            dataSet = createBlackAndWhiteTrainingSet(list, rgbData);
        }
        dataSet.setLabel(str3);
        // toArray(new String[0]) yields a real String[]; the former (String[]) cast of toArray()
        // threw ClassCastException for ArrayList-backed label lists.
        // NOTE(review): this replaces the column-name array with one holding only the labels.
        // Confirm against DataSet whether the input/output-aligned names set above should be kept.
        dataSet.setColumnNames(list.toArray(new String[0]));
        // The former hard-coded save to a developer's local D:\ path was removed. Callers that
        // want the set persisted should save it themselves.
        return dataSet;
    }
}
