/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.breastCancer;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.data.norm.MaxNormalizer;

public class BrestCancerSample
implements LearningEventListener {
    public int[] count = new int[3];
    public int[] correct = new int[3];
    int unpredicted = 0;

    public static void main(String[] args) {
        new BrestCancerSample().run();
    }

    public void run() {
        System.out.println("Creating training and test set from file...");
        String trainingSetFileName = "data_sets/breast cancer.txt";
        int inputsCount = 30;
        int outputsCount = 2;
        DataSet dataSet = DataSet.createFromFile(trainingSetFileName, inputsCount, outputsCount, ",");
        dataSet.shuffle();
        MaxNormalizer normalizer = new MaxNormalizer();
        normalizer.normalize(dataSet);
        List<DataSet> trainingAndTestSet = dataSet.split(70, 30);
        DataSet trainingSet = trainingAndTestSet.get(0);
        DataSet testSet = trainingAndTestSet.get(1);
        System.out.println("Creating neural network...");
        MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(inputsCount, 16, outputsCount);
        MomentumBackpropagation learningRule = (MomentumBackpropagation)neuralNet.getLearningRule();
        learningRule.addListener(this);
        learningRule.setLearningRate(0.3);
        learningRule.setMaxError(0.001);
        learningRule.setMaxIterations(5000);
        System.out.println("Training network...");
        neuralNet.learn(trainingSet);
        System.out.println("Testing network...\n\n");
        this.testNeuralNetwork(neuralNet, testSet);
        System.out.println("Done.");
        System.out.println("**************************************************");
    }

    public void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
        System.out.println("**************************************************");
        System.out.println("**********************RESULT**********************");
        System.out.println("**************************************************");
        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();
            int predicted = BrestCancerSample.maxOutput(networkOutput);
            double[] networkDesiredOutput = testSetRow.getDesiredOutput();
            int ideal = BrestCancerSample.maxOutput(networkDesiredOutput);
            this.keepScore(predicted, ideal);
        }
        System.out.println("Total cases: " + this.count[2] + ". ");
        System.out.println("Correctly predicted cases: " + this.correct[2] + ". ");
        System.out.println("Incorrectly predicted cases: " + (this.count[2] - this.correct[2] - this.unpredicted) + ". ");
        System.out.println("Unrecognized cases: " + this.unpredicted + ". ");
        double percentTotal = (double)this.correct[2] * 100.0 / (double)this.count[2];
        System.out.println("Predicted correctly: " + this.formatDecimalNumber(percentTotal) + "%. ");
        double percentM = (double)this.correct[0] * 100.0 / (double)this.count[0];
        System.out.println("Prediction for 'M (malignant)' => (Correct/total): " + this.correct[0] + "/" + this.count[0] + "(" + this.formatDecimalNumber(percentM) + "%). ");
        double percentB = (double)this.correct[1] * 100.0 / (double)this.count[1];
        System.out.println("Prediction for 'B (benign)' => (Correct/total): " + this.correct[1] + "/" + this.count[1] + "(" + this.formatDecimalNumber(percentB) + "%). ");
    }

    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation)event.getSource();
        if (event.getEventType().equals((Object)LearningEvent.Type.LEARNING_STOPPED)) {
            double error = bp.getTotalNetworkError();
            System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations, ");
            System.out.println("With total error: " + this.formatDecimalNumber(error));
        } else {
            System.out.println("Iteration: " + bp.getCurrentIteration() + " | Network error: " + bp.getTotalNetworkError());
        }
    }

    public static int maxOutput(double[] array) {
        double max = array[0];
        int index = 0;
        for (int i = 0; i < array.length; ++i) {
            if (!(array[i] > max)) continue;
            index = i;
            max = array[i];
        }
        if (max < 0.5) {
            return -1;
        }
        return index;
    }

    public void keepScore(int prediction, int ideal) {
        int n = ideal;
        this.count[n] = this.count[n] + 1;
        this.count[2] = this.count[2] + 1;
        if (prediction == ideal) {
            int n2 = ideal;
            this.correct[n2] = this.correct[n2] + 1;
            this.correct[2] = this.correct[2] + 1;
        }
        if (prediction == -1) {
            ++this.unpredicted;
        }
    }

    public String formatDecimalNumber(double number) {
        return new BigDecimal(number).setScale(4, RoundingMode.HALF_UP).toString();
    }
}

