package org.neuroph.samples.forestCover;

import java.math.BigDecimal;
import java.math.RoundingMode;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.nnet.MultiLayerPerceptron;

/**
 * Evaluates a trained forest-cover neural network against a test data set and prints
 * overall and per-tree-type accuracy statistics to standard output.
 *
 * <p>Scores are kept in {@link #count}, {@link #correct} and {@link #unpredicted}. Indices
 * {@code 0..6} of the arrays hold per-tree-type tallies; index {@code 7} holds the overall
 * total. The tallies are reset at the start of every {@link #testNeuralNetwork} call, so one
 * instance can be used for several evaluations.
 *
 * <p>Not thread-safe: the tallies are plain mutable fields.
 */
public class Evaluate {
    /** Number of tree types reported; valid tree-type indices are {@code 0..TREE_TYPES-1}. */
    private static final int TREE_TYPES = 7;

    /** Slot in {@link #count} / {@link #correct} that holds the overall total. */
    private static final int TOTAL = TREE_TYPES;

    /** Minimum winning activation for {@link #maxOutput} to report a class instead of -1. */
    private static final double RECOGNITION_THRESHOLD = 0.5d;

    private final Config config;
    int[] count = new int[TREE_TYPES + 1];
    int[] correct = new int[TREE_TYPES + 1];
    int unpredicted = 0;

    public Evaluate(Config config) {
        this.config = config;
    }

    /**
     * Loads the trained network and the test set whose file names come from the configuration,
     * then scores the network on every test row and prints the results.
     */
    public void evaluate() {
        System.out.println("Evaluating neural network...");
        // No MultiLayerPerceptron cast needed: testNeuralNetwork accepts any NeuralNetwork.
        NeuralNetwork neuralNetwork = NeuralNetwork.createFromFile(this.config.getTrainedNetworkFileName());
        DataSet testSet = DataSet.load(this.config.getTestFileName());
        testNeuralNetwork(neuralNetwork, testSet);
    }

    /**
     * Runs every row of {@code dataSet} through {@code neuralNetwork}, tallies correct,
     * incorrect and unrecognized predictions, and prints a summary.
     *
     * <p>Previous tallies are cleared first, so repeated calls do not accumulate.
     *
     * @param neuralNetwork network to evaluate; its output is interpreted by {@link #maxOutput}
     * @param dataSet       test rows; the winning index of each row's desired output (see
     *                      {@link #maxOutput}) is taken as the true tree type
     * @throws IllegalArgumentException if a row's desired output does not identify a tree type
     */
    public void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        resetScores();
        for (DataSetRow row : dataSet.getRows()) {
            neuralNetwork.setInput(row.getInput());
            neuralNetwork.calculate();
            keepScore(maxOutput(neuralNetwork.getOutput()), maxOutput(row.getDesiredOutput()));
        }
        printReport();
    }

    /**
     * Returns the index of the largest value in {@code values}, or {@code -1} when that value is
     * below 0.5 or the array is {@code null} or empty. Ties go to the lowest index.
     *
     * @param values output activations (network output or desired output)
     * @return the winning index, or {@code -1} if no value is large enough
     */
    public static int maxOutput(double[] values) {
        if (values == null || values.length == 0) {
            return -1;
        }
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return values[best] < RECOGNITION_THRESHOLD ? -1 : best;
    }

    /**
     * Records the outcome of a single prediction.
     *
     * @param predicted index returned by {@link #maxOutput} for the network output; {@code -1}
     *                  means the network did not recognize any tree type
     * @param desired   index of the true tree type, in {@code [0, 7)}
     * @throws IllegalArgumentException if {@code desired} is not a valid tree-type index
     */
    public void keepScore(int predicted, int desired) {
        if (desired < 0 || desired >= TREE_TYPES) {
            // -1 would index out of bounds and 7 would alias the TOTAL slot; fail before mutating.
            throw new IllegalArgumentException(
                    "Desired output does not identify a tree type (index " + desired + ")");
        }
        this.count[desired]++;
        this.count[TOTAL]++;
        if (predicted == -1) {
            this.unpredicted++;
        } else if (predicted == desired) {
            this.correct[desired]++;
            this.correct[TOTAL]++;
        }
    }

    /** Clears all tallies so a new evaluation starts from zero. */
    private void resetScores() {
        this.count = new int[TREE_TYPES + 1];
        this.correct = new int[TREE_TYPES + 1];
        this.unpredicted = 0;
    }

    /** Prints the overall and per-tree-type statistics gathered by {@link #keepScore}. */
    private void printReport() {
        int total = this.count[TOTAL];
        int totalCorrect = this.correct[TOTAL];
        System.out.println("Total cases: " + total + ". ");
        System.out.println("Correct cases: " + totalCorrect + ". ");
        System.out.println("Incorrectly predicted cases: " + (total - totalCorrect - this.unpredicted) + ". ");
        System.out.println("Unrecognized cases: " + this.unpredicted + ". ");
        System.out.println("Predicted correctly: " + formatPercentage(totalCorrect, total) + ". ");
        for (int type = 0; type < TREE_TYPES; type++) {
            System.out.println("Tree type: " + (type + 1) + " - Correct/total: " + this.correct[type]
                    + "/" + this.count[type] + "(" + formatPercentage(this.correct[type], this.count[type]) + "). ");
        }
    }

    /**
     * Formats {@code part / total} as a percentage with two decimals, e.g. {@code "85.32%"},
     * or returns {@code "n/a"} when {@code total} is zero.
     */
    private String formatPercentage(int part, int total) {
        if (total == 0) {
            // 0/0 is NaN, which BigDecimal cannot represent (it throws NumberFormatException).
            return "n/a";
        }
        return formatDecimalNumber(part * 100.0d / total) + "%";
    }

    /** Rounds {@code d} half-up to two decimal places. */
    private String formatDecimalNumber(double d) {
        // valueOf uses the shortest decimal representation of d, so e.g. 1.005 rounds to 1.01
        // rather than to 1.00 as the exact binary expansion used by new BigDecimal(double) would.
        return BigDecimal.valueOf(d).setScale(2, RoundingMode.HALF_UP).toString();
    }
}
