/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.forestCover;

import java.math.BigDecimal;
import java.math.RoundingMode;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.samples.forestCover.Config;

/**
 * Evaluates a trained forest-cover classification network against a test data set and prints
 * overall and per-tree-type accuracy statistics to standard output.
 *
 * <p>Classes are the 0-based indexes of the output neurons; tree types are printed 1-based. The
 * last slot ({@code TOTAL_INDEX}) of {@link #count} and {@link #correct} holds totals over all
 * classes.
 *
 * <p>NOTE: the counters are instance state and are never reset, so calling
 * {@link #testNeuralNetwork} more than once on the same instance accumulates statistics.
 */
public class Evaluate {
    /** Number of tree-type classes tracked individually. */
    private static final int NUM_TREE_TYPES = 7;
    /** Slot in the counter arrays that holds the totals over all classes. */
    private static final int TOTAL_INDEX = NUM_TREE_TYPES;
    /** Minimum winning activation for an output vector to count as a prediction. */
    private static final double PREDICTION_THRESHOLD = 0.5;

    private final Config config;
    /** Rows seen per ideal class; the grand total is stored at {@code TOTAL_INDEX}. */
    int[] count = new int[NUM_TREE_TYPES + 1];
    /** Correctly predicted rows per ideal class; the grand total is stored at {@code TOTAL_INDEX}. */
    int[] correct = new int[NUM_TREE_TYPES + 1];
    /** Rows for which no network output reached {@code PREDICTION_THRESHOLD}. */
    int unpredicted = 0;

    /**
     * Creates an evaluator.
     *
     * @param config supplies the trained-network file name and the test data file name
     */
    public Evaluate(Config config) {
        this.config = config;
    }

    /**
     * Loads the trained network and the test data set named by the configuration, then runs
     * {@link #testNeuralNetwork} on them.
     *
     * <p>The loaded network is cast to {@link MultiLayerPerceptron}, so a file holding a different
     * network type fails with {@link ClassCastException}.
     */
    public void evaluate() {
        System.out.println("Evaluating neural network...");
        MultiLayerPerceptron neuralNet = (MultiLayerPerceptron) NeuralNetwork.createFromFile(this.config.getTrainedNetworkFileName());
        DataSet dataSet = DataSet.load(this.config.getTestFileName());
        this.testNeuralNetwork(neuralNet, dataSet);
    }

    /**
     * Feeds every row of {@code testSet} through {@code neuralNet}, scores the winning output
     * against the winning desired output, and prints the accumulated statistics.
     *
     * @param neuralNet the network to evaluate
     * @param testSet rows whose desired output marks the ideal class
     * @throws IllegalArgumentException if a row's desired output does not identify a tree type
     *     (see {@link #keepScore})
     */
    public void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            int predicted = Evaluate.maxOutput(neuralNet.getOutput());
            int ideal = Evaluate.maxOutput(testSetRow.getDesiredOutput());
            this.keepScore(predicted, ideal);
        }
        int total = this.count[TOTAL_INDEX];
        int totalCorrect = this.correct[TOTAL_INDEX];
        System.out.println("Total cases: " + total + ". ");
        System.out.println("Correct cases: " + totalCorrect + ". ");
        System.out.println("Incorrectly predicted cases: " + (total - totalCorrect - this.unpredicted) + ". ");
        System.out.println("Unrecognized cases: " + this.unpredicted + ". ");
        System.out.println("Predicted correctly: " + this.formatPercentage(totalCorrect, total) + "%. ");
        for (int i = 0; i < NUM_TREE_TYPES; ++i) {
            System.out.println("Tree type: " + (i + 1) + " - Correct/total: " + this.correct[i] + "/" + this.count[i] + "(" + this.formatPercentage(this.correct[i], this.count[i]) + "%). ");
        }
    }

    /**
     * Returns the index of the largest value in {@code array}, or -1 when that value is below
     * {@code PREDICTION_THRESHOLD} (0.5) or the array is null/empty. Ties go to the lowest index.
     *
     * @param array output or desired-output activations
     * @return winning index, or -1 when there is no confident winner
     */
    public static int maxOutput(double[] array) {
        if (array == null || array.length == 0) {
            return -1;
        }
        double max = array[0];
        int index = 0;
        for (int i = 1; i < array.length; ++i) {
            // Strict '>' keeps the first occurrence on ties.
            if (array[i] > max) {
                index = i;
                max = array[i];
            }
        }
        return max < PREDICTION_THRESHOLD ? -1 : index;
    }

    /**
     * Records one test case.
     *
     * @param actual the predicted class index, or -1 if the network made no prediction
     * @param ideal the expected class index, in {@code [0, NUM_TREE_TYPES)}
     * @throws IllegalArgumentException if {@code ideal} is outside that range (for example -1,
     *     which {@link #maxOutput} returns when no desired output reaches the threshold)
     */
    public void keepScore(int actual, int ideal) {
        if (ideal < 0 || ideal >= NUM_TREE_TYPES) {
            throw new IllegalArgumentException(
                    "Ideal class index out of range [0, " + NUM_TREE_TYPES + "): " + ideal);
        }
        this.count[ideal]++;
        this.count[TOTAL_INDEX]++;
        if (actual == ideal) {
            this.correct[ideal]++;
            this.correct[TOTAL_INDEX]++;
        }
        if (actual == -1) {
            ++this.unpredicted;
        }
    }

    /**
     * Formats {@code part / whole} as a percentage with two decimals. A {@code whole} of zero
     * yields 0.00, because 0/0 would be NaN and {@link BigDecimal} rejects NaN.
     */
    private String formatPercentage(int part, int whole) {
        double percent = whole == 0 ? 0.0 : part * 100.0 / whole;
        return this.formatDecimalNumber(percent);
    }

    /**
     * Rounds {@code number} half-up to two decimal places.
     *
     * <p>{@link BigDecimal#valueOf(double)} starts from the shortest decimal representation of the
     * double (e.g. "12.345"), not its exact binary expansion (12.34499...). HALF_UP therefore
     * rounds the value a reader actually sees.
     */
    private String formatDecimalNumber(double number) {
        return BigDecimal.valueOf(number).setScale(2, RoundingMode.HALF_UP).toString();
    }
}

