/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.segmentChallenge;

import java.math.BigDecimal;
import java.math.RoundingMode;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.data.norm.MaxNormalizer;

public class SegmentChallengeSample
implements LearningEventListener {
    public int[] count = new int[8];
    public int[] correct = new int[8];
    int unpredicted = 0;

    public static void main(String[] args) {
        new SegmentChallengeSample().run();
    }

    public void run() {
        System.out.println("Creating training and test set from file...");
        String trainingSetFileName = "data_sets/segment challenge.txt";
        String testSetFileName = "data_sets/segment test.txt";
        int inputsCount = 19;
        int outputsCount = 7;
        DataSet trainingSet = DataSet.createFromFile(trainingSetFileName, inputsCount, outputsCount, ",");
        System.out.println("Training set size: " + trainingSet.getRows().size());
        trainingSet.shuffle();
        trainingSet.shuffle();
        MaxNormalizer normalizer = new MaxNormalizer();
        normalizer.normalize(trainingSet);
        DataSet testSet = DataSet.createFromFile(testSetFileName, inputsCount, outputsCount, ",");
        System.out.println("Test set size: " + testSet.getRows().size());
        System.out.println("--------------------------------------------------");
        testSet.shuffle();
        testSet.shuffle();
        normalizer.normalize(testSet);
        System.out.println("Creating neural network...");
        MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(inputsCount, 17, 10, outputsCount);
        MomentumBackpropagation learningRule = (MomentumBackpropagation)neuralNet.getLearningRule();
        learningRule.addListener(this);
        learningRule.setLearningRate(0.01);
        learningRule.setMaxError(0.001);
        learningRule.setMaxIterations(12000);
        System.out.println("Training network...");
        neuralNet.learn(trainingSet);
        System.out.println("Testing network...\n\n");
        this.testNeuralNetwork(neuralNet, testSet);
        System.out.println("Done.");
        System.out.println("**************************************************");
    }

    public void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
        System.out.println("**************************************************");
        System.out.println("**********************RESULT**********************");
        System.out.println("**************************************************");
        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();
            int predicted = SegmentChallengeSample.maxOutput(networkOutput);
            double[] networkDesiredOutput = testSetRow.getDesiredOutput();
            int ideal = SegmentChallengeSample.maxOutput(networkDesiredOutput);
            this.keepScore(predicted, ideal);
        }
        System.out.println("Total cases: " + this.count[7] + ". ");
        System.out.println("Correctly predicted cases: " + this.correct[7] + ". ");
        System.out.println("Incorrectly predicted cases: " + (this.count[7] - this.correct[7] - this.unpredicted) + ". ");
        System.out.println("Unrecognized cases: " + this.unpredicted + ". ");
        double percentTotal = (double)this.correct[7] * 100.0 / (double)this.count[7];
        System.out.println("Predicted correctly: " + this.formatDecimalNumber(percentTotal) + "%. ");
        for (int i = 0; i < this.correct.length - 1; ++i) {
            double p = (double)this.correct[i] * 100.0 / (double)this.count[i];
            System.out.println("Segment class: " + this.getClasificationClass(i + 1) + " - Correct/total: " + this.correct[i] + "/" + this.count[i] + "(" + this.formatDecimalNumber(p) + "%). ");
        }
        this.count = new int[8];
        this.correct = new int[8];
        this.unpredicted = 0;
    }

    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation)event.getSource();
        if (event.getEventType().equals((Object)LearningEvent.Type.LEARNING_STOPPED)) {
            double error = bp.getTotalNetworkError();
            System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations, ");
            System.out.println("With total error: " + this.formatDecimalNumber(error));
        } else {
            System.out.println("Iteration: " + bp.getCurrentIteration() + " | Network error: " + bp.getTotalNetworkError());
        }
    }

    public static int maxOutput(double[] array) {
        double max = array[0];
        int index = 0;
        for (int i = 0; i < array.length; ++i) {
            if (!(array[i] > max)) continue;
            index = i;
            max = array[i];
        }
        if (max < 0.5) {
            return -1;
        }
        return index;
    }

    public void keepScore(int prediction, int ideal) {
        int n = ideal;
        this.count[n] = this.count[n] + 1;
        this.count[7] = this.count[7] + 1;
        if (prediction == ideal) {
            int n2 = ideal;
            this.correct[n2] = this.correct[n2] + 1;
            this.correct[7] = this.correct[7] + 1;
        }
        if (prediction == -1) {
            ++this.unpredicted;
        }
    }

    public String formatDecimalNumber(double number) {
        return new BigDecimal(number).setScale(4, RoundingMode.HALF_UP).toString();
    }

    public String getClasificationClass(int i) {
        switch (i) {
            case 1: {
                return "brickface";
            }
            case 2: {
                return "sky";
            }
            case 3: {
                return "foliage";
            }
            case 4: {
                return "cement";
            }
            case 5: {
                return "window";
            }
            case 6: {
                return "path";
            }
            case 7: {
                return "grass";
            }
        }
        return "error";
    }
}

