/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.diabetes;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.events.LearningEvent;
import org.neuroph.core.events.LearningEventListener;
import org.neuroph.nnet.MultiLayerPerceptron;
import org.neuroph.nnet.learning.BackPropagation;
import org.neuroph.nnet.learning.MomentumBackpropagation;
import org.neuroph.util.data.norm.MaxNormalizer;

public class DiabetesSample
implements LearningEventListener {
    int total;
    int correct;
    int incorrect;
    float classificationThreshold = 0.5f;

    public void run() {
        System.out.println("Creating training and test set from file...");
        String trainingSetFileName = "data_sets/diabetes.txt";
        int inputsCount = 8;
        int outputsCount = 1;
        DataSet dataSet = DataSet.createFromFile(trainingSetFileName, inputsCount, outputsCount, ",");
        dataSet.shuffle();
        MaxNormalizer normalizer = new MaxNormalizer();
        normalizer.normalize(dataSet);
        List<DataSet> trainingAndTestSet = dataSet.split(70, 30);
        DataSet trainingSet = trainingAndTestSet.get(0);
        DataSet testSet = trainingAndTestSet.get(1);
        System.out.println("Creating neural network...");
        MultiLayerPerceptron neuralNet = new MultiLayerPerceptron(inputsCount, 20, 10, outputsCount);
        MomentumBackpropagation learningRule = (MomentumBackpropagation)neuralNet.getLearningRule();
        learningRule.addListener(this);
        learningRule.setLearningRate(0.6);
        learningRule.setMaxError(0.07);
        learningRule.setMaxIterations(100000);
        System.out.println("Training network...");
        neuralNet.learn(trainingSet);
        System.out.println("Testing network...");
        this.testNeuralNetwork(neuralNet, testSet);
    }

    public void testNeuralNetwork(NeuralNetwork neuralNet, DataSet testSet) {
        System.out.println("**********************RESULT**********************");
        for (DataSetRow testSetRow : testSet.getRows()) {
            neuralNet.setInput(testSetRow.getInput());
            neuralNet.calculate();
            double[] networkOutput = neuralNet.getOutput();
            int predicted = this.interpretOutput(networkOutput);
            double[] desiredOutput = testSetRow.getDesiredOutput();
            int target = (int)desiredOutput[0];
            this.countPredictions(predicted, target);
        }
        System.out.println("Total cases: " + this.total + ". ");
        System.out.println("Correctly predicted cases: " + this.correct);
        System.out.println("Incorrectly predicted cases: " + this.incorrect);
        double percentTotal = (double)this.correct / (double)this.total * 100.0;
        System.out.println("Predicted correctly: " + this.formatDecimalNumber(percentTotal) + "%. ");
    }

    @Override
    public void handleLearningEvent(LearningEvent event) {
        BackPropagation bp = (BackPropagation)event.getSource();
        if (event.getEventType().equals((Object)LearningEvent.Type.LEARNING_STOPPED)) {
            double error = bp.getTotalNetworkError();
            System.out.println("Training completed in " + bp.getCurrentIteration() + " iterations, ");
            System.out.println("With total error: " + this.formatDecimalNumber(error));
        } else {
            System.out.println("Iteration: " + bp.getCurrentIteration() + " | Network error: " + bp.getTotalNetworkError());
        }
    }

    public int interpretOutput(double[] array) {
        if (array[0] >= (double)this.classificationThreshold) {
            return 1;
        }
        return 0;
    }

    public void countPredictions(int prediction, int target) {
        if (prediction == target) {
            ++this.correct;
        } else {
            ++this.incorrect;
        }
        ++this.total;
    }

    public String formatDecimalNumber(double number) {
        return new BigDecimal(number).setScale(4, RoundingMode.HALF_UP).toString();
    }

    public static void main(String[] args) {
        new DiabetesSample().run();
    }
}

