/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.samples.forestCover;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.samples.forestCover.Config;
import org.neuroph.util.data.norm.MaxNormalizer;

/**
 * Prepares the data sets used by the forest cover sample: splits the raw data file into
 * training and test subsets, builds a training set with a capped number of rows per type,
 * and max-normalizes that balanced set. All file names come from the supplied {@link Config}.
 *
 * <p>Each row has {@value #INPUT_COUNT} inputs and a {@value #OUTPUT_COUNT}-element desired
 * output vector; a row's type is the position of the {@code 1.0} entry in that vector.
 */
public class GenerateData {
    /** Number of input columns per row in the raw data file. */
    private static final int INPUT_COUNT = 54;
    /** Number of desired-output columns per row (one per type). */
    private static final int OUTPUT_COUNT = 7;
    /** Percentage of rows placed in the training subset. */
    private static final int TRAINING_PERCENT = 75;
    /** Percentage of rows placed in the test subset. */
    private static final int TEST_PERCENT = 25;
    /** Labels used when reporting per-type counts; index {@code i} describes type {@code i + 1}. */
    private static final String[] TYPE_LABELS = {
        "First", "Second", "Third", "Fourth", "Fifth", "Sixth", "Seventh"
    };

    private final Config config;

    public GenerateData(Config config) {
        this.config = config;
    }

    /**
     * Loads the raw data file, shuffles it, and splits it into training and test subsets
     * ({@value #TRAINING_PERCENT}% / {@value #TEST_PERCENT}%). The training subset is saved
     * unnormalized; the test subset is max-normalized, shuffled again, and saved. Sizes and
     * percentages of both subsets are printed.
     */
    public void createTrainingAndTestSet() {
        DataSet dataSet = this.createDataSet();
        dataSet.shuffle();
        DataSet[] trainingAndTestSet = dataSet.createTrainingAndTestSubsets(TRAINING_PERCENT, TEST_PERCENT);
        DataSet trainingSet = trainingAndTestSet[0];
        System.out.println("Saving training set to file...");
        trainingSet.save(this.config.getTrainingFileName());
        System.out.println("Training set successfully saved!");
        DataSet testSet = trainingAndTestSet[1];
        System.out.println("Normalizing test set...");
        // NOTE(review): the test set is normalized with its own maxima, independently of the
        // training data (which is normalized later from the balanced subset) — confirm intended.
        MaxNormalizer nor = new MaxNormalizer();
        nor.normalize(testSet);
        System.out.println("Saving normalized test set to file...");
        testSet.shuffle();
        testSet.save(this.config.getTestFileName());
        System.out.println("Normalized test set successfully saved!");
        int totalRows = dataSet.getRows().size();
        int trainingRows = trainingSet.getRows().size();
        int testRows = testSet.getRows().size();
        System.out.println("Training set size: " + trainingRows + " rows. ");
        System.out.println("Test set size: " + testRows + " rows. ");
        System.out.println("-----------------------------------------------");
        System.out.println("Training set takes " + this.formatDecimalNumber(percentOf(trainingRows, totalRows)) + "% of main data set. ");
        System.out.println("Test set takes " + this.formatDecimalNumber(percentOf(testRows, totalRows)) + "% of main data set. ");
    }

    /** Reads the comma-separated raw data file configured in {@link Config} and reports its size. */
    private DataSet createDataSet() {
        DataSet dataSet = DataSet.createFromFile(this.config.getDataFilePath(), INPUT_COUNT, OUTPUT_COUNT, ",");
        System.out.println("Main data set size: " + dataSet.getRows().size() + " rows. ");
        return dataSet;
    }

    /**
     * Returns {@code part} as a percentage of {@code whole}, or {@code 0.0} when {@code whole}
     * is zero (an empty data set would otherwise yield NaN, which BigDecimal rejects).
     */
    private static double percentOf(int part, int whole) {
        if (whole == 0) {
            return 0.0;
        }
        return (double) part * 100.0 / (double) whole;
    }

    /**
     * Formats {@code number} with exactly three decimal places using half-up rounding.
     * {@code BigDecimal.valueOf} is used instead of {@code new BigDecimal(double)} so rounding
     * operates on the decimal form of the value rather than its exact binary approximation,
     * which can fall just below a {@code .xxx5} boundary and round the wrong way.
     */
    private String formatDecimalNumber(double number) {
        return BigDecimal.valueOf(number).setScale(3, RoundingMode.HALF_UP).toString();
    }

    /**
     * Builds a training set holding at most {@code count} rows of each type, taking rows in the
     * order they appear in the saved training file, prints per-type counts, and saves the
     * result to the balanced file. Rows whose desired output has no {@code 1.0} entry within the
     * first {@value #OUTPUT_COUNT} positions are reported and skipped.
     *
     * @param count maximum number of rows kept per type
     */
    public void createBalancedTrainingSet(int count) {
        DataSet balanced = new DataSet(INPUT_COUNT, OUTPUT_COUNT);
        int[] perType = new int[OUTPUT_COUNT];
        DataSet trainingSet = DataSet.load(this.config.getTrainingFileName());
        List<DataSetRow> rows = trainingSet.getRows();
        System.out.println("Training set size: " + rows.size() + " rows. ");
        for (DataSetRow row : rows) {
            int type = indexOfOne(row.getDesiredOutput());
            if (type < 0 || type >= OUTPUT_COUNT) {
                System.out.println("Error with output vector size! ");
                continue;
            }
            if (perType[type] < count) {
                balanced.addRow(row);
                perType[type]++;
            }
        }
        System.out.println("Balanced training set size: " + balanced.getRows().size() + " rows. ");
        System.out.println("Samples per tree: ");
        for (int i = 0; i < OUTPUT_COUNT; i++) {
            System.out.println(TYPE_LABELS[i] + " type: " + perType[i] + " samples. ");
        }
        balanced.save(this.config.getBalancedFileName());
    }

    /** Returns the index of the first element equal to {@code 1.0}, or {@code -1} if none. */
    private static int indexOfOne(double[] desiredOutput) {
        for (int i = 0; i < desiredOutput.length; i++) {
            if (desiredOutput[i] == 1.0) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Loads the balanced training set, max-normalizes it, shuffles it, and saves it to the
     * normalized-balanced file.
     */
    public void normalizeBalancedTrainingSet() {
        DataSet dataSet = DataSet.load(this.config.getBalancedFileName());
        MaxNormalizer normalizer = new MaxNormalizer();
        normalizer.normalize(dataSet);
        System.out.println("Saving normalized training data set to file... ");
        // A single shuffle already yields a random permutation; shuffling twice added nothing.
        dataSet.shuffle();
        dataSet.save(this.config.getNormalizedBalancedFileName());
        System.out.println("Normalized training data set successfully saved!");
    }
}

