/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.eval;

import java.util.HashMap;
import java.util.Map;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.ErrorEvaluator;
import org.neuroph.eval.EvaluationResult;
import org.neuroph.eval.Evaluator;
import org.neuroph.eval.classification.ClassificationMetrics;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs a neural network over a data set and feeds every prediction to a set of
 * registered {@link Evaluator}s, keeping at most one evaluator per concrete class.
 *
 * <p>A new instance always contains an {@link ErrorEvaluator} backed by
 * {@link MeanSquaredError}. Classification evaluators
 * ({@link ClassifierEvaluator.MultiClass}, {@link ClassifierEvaluator.Binary}) must be
 * registered explicitly with {@link #addEvaluator(Evaluator)}.
 *
 * <p>Instances hold mutable evaluator state in a plain {@link HashMap} and perform no
 * synchronization.
 */
public final class Evaluation {
    private static final Logger LOGGER = LoggerFactory.getLogger("neuroph");

    /** Visual separator written between sections of the full-evaluation log output. */
    private static final String SEPARATOR =
            "##############################################################################";

    /** Registered evaluators, keyed by their concrete runtime class. */
    private final Map<Class<?>, Evaluator> evaluators = new HashMap<>();

    /** Creates an evaluation pre-populated with a mean-squared-error evaluator. */
    public Evaluation() {
        this.addEvaluator(new ErrorEvaluator(new MeanSquaredError()));
    }

    /**
     * Evaluates {@code neuralNetwork} on every row of {@code dataSet}.
     *
     * <p>All registered evaluators are reset first. Then, for each row, the row input is
     * fed through the network and the network output is handed, together with the row's
     * desired output, to every registered evaluator.
     *
     * <p>The confusion matrix in the result comes from the
     * {@link ClassifierEvaluator.MultiClass} evaluator when the network has more than one
     * output, otherwise from {@link ClassifierEvaluator.Binary}. If the matching classifier
     * evaluator has not been registered, the confusion matrix is left {@code null} rather
     * than failing with a {@link NullPointerException}.
     *
     * @param neuralNetwork network to evaluate
     * @param dataSet rows to run through the network
     * @return a result holding the data set, the confusion matrix (possibly {@code null})
     *     and the mean squared error
     */
    public EvaluationResult evaluateDataSet(NeuralNetwork<?> neuralNetwork, DataSet dataSet) {
        for (Evaluator evaluator : evaluators.values()) {
            evaluator.reset();
        }

        for (DataSetRow dataRow : dataSet.getRows()) {
            neuralNetwork.setInput(dataRow.getInput());
            neuralNetwork.calculate();
            // Read output and target once per row instead of once per evaluator.
            double[] predicted = neuralNetwork.getOutput();
            double[] desired = dataRow.getDesiredOutput();
            for (Evaluator evaluator : evaluators.values()) {
                evaluator.processNetworkResult(predicted, desired);
            }
        }

        EvaluationResult result = new EvaluationResult();
        result.setDataSet(dataSet);
        result.setConfusionMatrix(findConfusionMatrix(neuralNetwork.getOutputsCount() > 1));
        result.setMeanSquareError(getMeanSquareError());
        return result;
    }

    /**
     * Returns the confusion matrix of the registered multi-class or binary classifier
     * evaluator, or {@code null} if that evaluator has not been registered.
     */
    private ConfusionMatrix findConfusionMatrix(boolean multiClass) {
        ClassifierEvaluator classifier = multiClass
                ? getEvaluator(ClassifierEvaluator.MultiClass.class)
                : getEvaluator(ClassifierEvaluator.Binary.class);
        return classifier == null ? null : classifier.getResult();
    }

    /**
     * Registers an evaluator, replacing any previously registered evaluator of the same
     * concrete class.
     *
     * @param evaluator evaluator to register
     * @throws IllegalArgumentException if {@code evaluator} is {@code null}
     */
    public void addEvaluator(Evaluator evaluator) {
        if (evaluator == null) {
            throw new IllegalArgumentException("Evaluator cannot be null!");
        }
        evaluators.put(evaluator.getClass(), evaluator);
    }

    /**
     * Looks up the evaluator registered under exactly {@code type} (no subtype matching).
     *
     * @param type concrete evaluator class used as the registration key
     * @param <T> evaluator type
     * @return the registered evaluator, or {@code null} if none is registered for
     *     {@code type}
     */
    public <T extends Evaluator> T getEvaluator(Class<T> type) {
        return type.cast(evaluators.get(type));
    }

    /**
     * Returns the live internal evaluator map. Changes made through the returned map
     * affect this evaluation.
     *
     * @return map of concrete evaluator class to evaluator
     */
    public Map<Class<?>, Evaluator> getEvaluators() {
        return evaluators;
    }

    /**
     * Returns the current result of the registered {@link ErrorEvaluator}, i.e. the mean
     * squared error accumulated by the last {@link #evaluateDataSet} run.
     *
     * @return mean squared error
     */
    public double getMeanSquareError() {
        return getEvaluator(ErrorEvaluator.class).getResult();
    }

    /**
     * Evaluates {@code neuralNet} on {@code dataSet} with a multi-class classifier and logs
     * the mean squared error, the confusion matrix and per-class classification metrics.
     *
     * @param neuralNet network to evaluate
     * @param dataSet data to evaluate on; its column names are used as class labels
     */
    public static void runFullEvaluation(NeuralNetwork<?> neuralNet, DataSet dataSet) {
        Evaluation evaluation = new Evaluation();
        // NOTE(review): class labels come from dataSet.getColumnNames(); confirm these are
        // the output class labels and not every (input + output) column.
        evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(dataSet.getColumnNames()));
        evaluation.evaluateDataSet(neuralNet, dataSet);

        LOGGER.info(SEPARATOR);
        LOGGER.info("MeanSquare Error: {}", evaluation.getMeanSquareError());
        LOGGER.info(SEPARATOR);

        ClassifierEvaluator classificationEvaluator =
                evaluation.getEvaluator(ClassifierEvaluator.MultiClass.class);
        ConfusionMatrix confusionMatrix = classificationEvaluator.getResult();
        LOGGER.info("Confusion Matrix: \r\n{}", confusionMatrix);
        LOGGER.info(SEPARATOR);

        LOGGER.info("Classification metrics: ");
        for (ClassificationMetrics cm : ClassificationMetrics.createFromMatrix(confusionMatrix)) {
            LOGGER.info("{}", cm);
        }
        LOGGER.info(SEPARATOR);
    }
}

