package org.neuroph.contrib.bpbench;

import java.util.Objects;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.eval.ClassifierEvaluator;
import org.neuroph.eval.Evaluation;
import org.neuroph.eval.classification.ConfusionMatrix;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.util.TransferFunctionType;

/* loaded from: input_file:org/neuroph/contrib/bpbench/AbstractTraining.class */
/**
 * Base class for a single benchmark training run.
 *
 * <p>An instance binds together the neural network, the dataset, the {@link TrainingSettings}
 * for the run, and a {@link TrainingStatistics} holder. Subclasses supply the learning-rule
 * configuration ({@link #setParameters()}) and the test procedure ({@link #testNeuralNet()}).
 */
public abstract class AbstractTraining {
    /** Network bound to this run; fixed at construction. */
    private final NeuralNetwork neuralNet;
    /** Dataset bound to this run; it is also the data evaluated by {@link #createMatrix()}. */
    private final DataSet dataset;
    /** Statistics holder created per instance; the reference itself is never replaced. */
    private final TrainingStatistics stats = new TrainingStatistics();
    /** Settings for this run; may be replaced via {@link #setSettings(TrainingSettings)}. */
    private TrainingSettings settings;

    /**
     * Creates a training run around an existing network.
     *
     * @param neuralNetwork network bound to this run; must not be {@code null}
     * @param dataSet dataset bound to this run; must not be {@code null}
     * @param trainingSettings settings for this run (stored as given)
     * @throws NullPointerException if {@code neuralNetwork} or {@code dataSet} is {@code null}
     */
    public AbstractTraining(NeuralNetwork neuralNetwork, DataSet dataSet, TrainingSettings trainingSettings) {
        // Both fields are final and dereferenced by createMatrix(); fail fast instead of later.
        this.neuralNet = Objects.requireNonNull(neuralNetwork, "neuralNetwork");
        this.dataset = Objects.requireNonNull(dataSet, "dataSet");
        this.settings = trainingSettings;
    }

    /**
     * Creates a training run with a newly built sigmoid {@link MultiLayerPerceptron}.
     *
     * <p>Equivalent to
     * {@code AbstractTraining(dataSet, trainingSettings, TransferFunctionType.SIGMOID)}.
     *
     * @param dataSet dataset bound to this run; its input/output sizes fix the network's
     *     outer layer sizes
     * @param trainingSettings settings for this run; {@code getHiddenNeurons()} sizes the single
     *     hidden layer
     * @throws NullPointerException if {@code dataSet} or {@code trainingSettings} is {@code null}
     */
    public AbstractTraining(DataSet dataSet, TrainingSettings trainingSettings) {
        this(dataSet, trainingSettings, TransferFunctionType.SIGMOID);
    }

    /**
     * Creates a training run with a newly built three-layer {@link MultiLayerPerceptron}.
     *
     * <p>The layers are: {@code dataSet.getInputSize()} inputs, one hidden layer of
     * {@code trainingSettings.getHiddenNeurons()} neurons, and {@code dataSet.getOutputSize()}
     * outputs.
     *
     * @param dataSet dataset bound to this run; must not be {@code null}
     * @param trainingSettings settings for this run; must not be {@code null}
     * @param transferFunctionType transfer function for the perceptron; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public AbstractTraining(DataSet dataSet, TrainingSettings trainingSettings,
            TransferFunctionType transferFunctionType) {
        this.dataset = Objects.requireNonNull(dataSet, "dataSet");
        this.settings = Objects.requireNonNull(trainingSettings, "trainingSettings");
        this.neuralNet = new MultiLayerPerceptron(
                Objects.requireNonNull(transferFunctionType, "transferFunctionType"),
                dataSet.getInputSize(),
                trainingSettings.getHiddenNeurons(),
                dataSet.getOutputSize());
    }

    /**
     * Subclass hook for testing the trained network; semantics are defined by implementations.
     */
    public abstract void testNeuralNet();

    /**
     * Subclass hook returning a {@link LearningRule}.
     *
     * <p>NOTE(review): presumably configures the rule from the current settings; confirm in
     * subclasses.
     *
     * @return the learning rule supplied by the subclass
     */
    public abstract LearningRule setParameters();

    /**
     * Evaluates the network on this run's dataset and returns the confusion matrix.
     *
     * <p>Class labels are the dataset's output column names. A
     * {@link ClassifierEvaluator.MultiClass} evaluator built from those labels is registered on a
     * fresh {@link Evaluation}.
     *
     * @return the confusion matrix produced by evaluating {@link #getNeuralNet()} on
     *     {@link #getDataset()}
     */
    public ConfusionMatrix createMatrix() {
        Evaluation evaluation = new Evaluation();
        int inputSize = dataset.getInputSize();
        int outputSize = dataset.getOutputSize();
        String[] classLabels = new String[outputSize];
        for (int i = 0; i < outputSize; i++) {
            // Output columns come right after the input columns in the column-name indexing.
            classLabels[i] = dataset.getColumnName(inputSize + i);
        }
        evaluation.addEvaluator(new ClassifierEvaluator.MultiClass(classLabels));
        return evaluation.evaluateDataSet(neuralNet, dataset).getConfusionMatrix();
    }

    /** @return the current training settings */
    public TrainingSettings getSettings() {
        return settings;
    }

    /** @param trainingSettings replacement settings for this run */
    public void setSettings(TrainingSettings trainingSettings) {
        this.settings = trainingSettings;
    }

    /** @return the dataset bound to this run */
    public DataSet getDataset() {
        return dataset;
    }

    /** @return the network bound to this run */
    public NeuralNetwork getNeuralNet() {
        return neuralNet;
    }

    /** @return the statistics holder for this run (same instance for the object's lifetime) */
    public TrainingStatistics getStats() {
        return stats;
    }
}
