package org.neuroph.samples.breastCancer;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.data.norm.MaxNormalizer;

/* loaded from: input_file:org/neuroph/samples/breastCancer/BreastCancerSample.class */
/**
 * Neuroph sample that trains a multi-layer perceptron on the data in
 * {@code data_sets/breast_cancer.txt} (30 comma-separated input columns, 1 output column)
 * and reports how many rows of a held-out test split it classifies correctly.
 *
 * <p>The class also acts as the learning-rule listener so it can print training progress.
 */
public class BreastCancerSample implements LearningEventListener {
    /** Comma-separated data file read by {@link #run()}. */
    private static final String DATA_FILE = "data_sets/breast_cancer.txt";
    /** Number of input columns in each row of {@link #DATA_FILE}. */
    private static final int INPUT_COUNT = 30;
    /** Number of output columns in each row of {@link #DATA_FILE}. */
    private static final int OUTPUT_COUNT = 1;
    /** Size of the single hidden layer. */
    private static final int HIDDEN_NEURONS = 16;

    /** Rows evaluated by the most recent {@link #testNeuralNetwork} call. */
    int total;
    /** Rows whose prediction matched the desired output. */
    int correct;
    /** Rows whose prediction did not match the desired output. */
    int incorrect;
    /** Network outputs at or above this value are interpreted as class {@code 1}, else {@code 0}. */
    float classificationThreshold = 0.5f;

    /**
     * Loads, shuffles and max-normalizes the data set, splits it into training and test parts
     * ({@code split(70, 30)}), trains the network with momentum back-propagation and prints
     * the test results.
     */
    public void run() {
        System.out.println("Creating training and test set from file...");
        DataSet fullSet = DataSet.createFromFile(DATA_FILE, INPUT_COUNT, OUTPUT_COUNT, ",");
        fullSet.shuffle();
        // NOTE(review): the normalizer is applied to the whole data set, so rows that later land
        // in the test split influence the scaling. Normalizing with training-set statistics only
        // would give a stricter evaluation.
        new MaxNormalizer().normalize(fullSet);
        List<DataSet> parts = fullSet.split(70, 30);
        DataSet trainingSet = parts.get(0);
        DataSet testSet = parts.get(1);

        MultiLayerPerceptron network =
                new MultiLayerPerceptron(INPUT_COUNT, HIDDEN_NEURONS, OUTPUT_COUNT);
        // Assumes the perceptron's default learning rule is a MomentumBackpropagation;
        // otherwise this cast throws ClassCastException.
        MomentumBackpropagation learningRule =
                (MomentumBackpropagation) network.getLearningRule();
        learningRule.addListener(this);
        learningRule.setLearningRate(0.3d);
        learningRule.setMaxError(0.01d);
        learningRule.setMaxIterations(500);

        System.out.println("Training network...");
        network.learn(trainingSet);
        System.out.println("Testing network...");
        testNeuralNetwork(network, testSet);
    }

    /**
     * Runs every row of {@code dataSet} through {@code neuralNetwork}, compares the thresholded
     * output with the row's first desired output and prints a summary.
     *
     * <p>The {@link #total}, {@link #correct} and {@link #incorrect} counters are reset first,
     * so they describe only this evaluation.
     *
     * @param neuralNetwork trained network to evaluate
     * @param dataSet rows to evaluate; an empty set reports an accuracy of 0
     */
    public void testNeuralNetwork(NeuralNetwork neuralNetwork, DataSet dataSet) {
        // Reset so that evaluating twice on the same instance does not accumulate counts.
        this.total = 0;
        this.correct = 0;
        this.incorrect = 0;

        System.out.println("********************** TEST RESULT **********************");
        for (DataSetRow row : dataSet.getRows()) {
            neuralNetwork.setInput(row.getInput());
            neuralNetwork.calculate();
            int predicted = interpretOutput(neuralNetwork.getOutput());
            // Round instead of truncating so a desired output stored as e.g. 0.9999 counts as 1.
            int actual = (int) Math.round(row.getDesiredOutput()[0]);
            countPredictions(predicted, actual);
        }
        System.out.println("Total cases: " + this.total + ". ");
        System.out.println("Correctly predicted cases: " + this.correct);
        System.out.println("Incorrectly predicted cases: " + this.incorrect);
        System.out.println("Predicted correctly: " + formatDecimalNumber(accuracyPercent()) + "%. ");
    }

    /** Percentage of correct predictions; floating-point division, and 0 when nothing was tested. */
    private double accuracyPercent() {
        if (this.total == 0) {
            return 0.0d;
        }
        return 100.0d * this.correct / this.total;
    }

    /**
     * Prints the iteration number and network error for every learning event except
     * {@code LEARNING_STOPPED}, for which it prints a completion summary instead.
     *
     * @param learningEvent event whose source is expected to be a {@link BackPropagation}
     *     (the cast throws ClassCastException otherwise)
     */
    @Override // org.neuroph.core.events.LearningEventListener
    public void handleLearningEvent(LearningEvent learningEvent) {
        BackPropagation rule = (BackPropagation) learningEvent.getSource();
        if (learningEvent.getEventType() != LearningEvent.Type.LEARNING_STOPPED) {
            System.out.println("Iteration: " + rule.getCurrentIteration() + " | Network error: " + rule.getTotalNetworkError());
            return;
        }
        System.out.println("Training completed in " + rule.getCurrentIteration() + " iterations, ");
        System.out.println("With total error: " + formatDecimalNumber(rule.getTotalNetworkError()));
    }

    /**
     * Maps the first network output to a class label.
     *
     * @param dArr network output vector; only element 0 is read
     * @return {@code 1} if {@code dArr[0] >= classificationThreshold}, otherwise {@code 0}
     */
    public int interpretOutput(double[] dArr) {
        return dArr[0] >= (double) this.classificationThreshold ? 1 : 0;
    }

    /**
     * Records one prediction, incrementing {@link #correct} or {@link #incorrect}, and {@link #total}.
     *
     * @param i predicted label
     * @param i2 actual (desired) label
     */
    public void countPredictions(int i, int i2) {
        if (i == i2) {
            this.correct++;
        } else {
            this.incorrect++;
        }
        this.total++;
    }

    /**
     * Formats {@code d} with exactly four decimal places using HALF_UP rounding.
     *
     * <p>{@code BigDecimal.valueOf} is used so rounding applies to the value's shortest decimal
     * representation rather than its exact binary expansion. NaN and infinities, which
     * {@code BigDecimal} cannot represent, are returned via {@link Double#toString(double)}
     * instead of throwing.
     *
     * @param d value to format
     * @return the formatted number
     */
    public String formatDecimalNumber(double d) {
        if (!Double.isFinite(d)) {
            return Double.toString(d);
        }
        return BigDecimal.valueOf(d).setScale(4, RoundingMode.HALF_UP).toString();
    }

    /** Entry point: runs the sample with the default configuration. */
    public static void main(String[] strArr) {
        new BreastCancerSample().run();
    }
}
