package org.neuroph.samples.eval;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.ErrorEvaluator;
import org.neuroph.eval.Evaluation;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.neuroph.nnet.MultiLayerPerceptron;

/* loaded from: input_file:org/neuroph/samples/eval/ClassifierEvaluationSample.class */
/**
 * Sample showing how to evaluate a stored multi-class classifier.
 *
 * <p>The program loads a network from {@code irisNet.nnet} and a data set from
 * {@code data_sets/iris_data_normalised.txt}, runs an {@link Evaluation} with an
 * {@link ErrorEvaluator} (mean squared error) and a multi-class
 * {@link ClassifierEvaluator}, and prints the resulting confusion matrix, the
 * per-class classification metrics and their average to standard output.
 */
public class ClassifierEvaluationSample {

    /** Relative path of the serialized network that is evaluated. */
    private static final String NETWORK_FILE = "irisNet.nnet";

    /** Relative path of the delimited data set file used for evaluation. */
    private static final String DATA_SET_FILE = "data_sets/iris_data_normalised.txt";

    /** Number of input columns read from each row of the data set file. */
    private static final int INPUT_COUNT = 4;

    /** Number of output columns read from each row of the data set file. */
    private static final int OUTPUT_COUNT = 3;

    /** Column delimiter used in the data set file. */
    private static final String DELIMITER = ",";

    /**
     * Builds the class labels handed to the multi-class evaluator.
     *
     * <p>NOTE(review): the label order presumably has to match the order of the
     * network outputs / data set output columns - confirm against the data file.
     *
     * @return a fresh array of class labels
     */
    private static String[] classLabels() {
        return new String[]{"Virginica", "Setosa", "Versicolor"};
    }

    /**
     * Runs the evaluation and prints the results.
     *
     * @param args command line arguments; not used
     */
    public static void main(String[] args) {
        Evaluation evaluation = new Evaluation();
        evaluation.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));

        MultiLayerPerceptron neuralNet =
                (MultiLayerPerceptron) NeuralNetwork.createFromFile(NETWORK_FILE);
        DataSet dataSet =
                DataSet.createFromFile(DATA_SET_FILE, INPUT_COUNT, OUTPUT_COUNT, DELIMITER);

        evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(classLabels()));
        evaluation.evaluateDataSet(neuralNet, dataSet);

        // The classifier evaluator is looked up by the same type it was registered with.
        ClassifierEvaluator classifierEvaluator =
                (ClassifierEvaluator) evaluation.getEvaluator(ClassifierEvaluator.MultiClass.class);
        ConfusionMatrix confusionMatrix = classifierEvaluator.getResult();

        System.out.println("Confusion matrix:\r\n");
        System.out.println(confusionMatrix.toString() + "\r\n\r\n");
        System.out.println("Classification metrics\r\n");

        ClassificationMetrics[] metrics = ClassificationMetrics.createFromMatrix(confusionMatrix);
        ClassificationMetrics.Stats average = ClassificationMetrics.average(metrics);
        for (ClassificationMetrics classMetrics : metrics) {
            System.out.println(classMetrics.toString() + "\r\n");
        }
        System.out.println(average.toString());
    }
}
