/*
 * Decompiled with CFR 0.152.
 */
package edu.uci.jforests.applications;

import edu.uci.jforests.config.TrainingConfig;
import edu.uci.jforests.dataset.Dataset;
import edu.uci.jforests.dataset.DatasetLoader;
import edu.uci.jforests.eval.AUC;
import edu.uci.jforests.eval.Accuracy;
import edu.uci.jforests.eval.BalancedYoundenIndex;
import edu.uci.jforests.eval.EvaluationMetric;
import edu.uci.jforests.eval.RMSE;
import edu.uci.jforests.learning.LearningModule;
import edu.uci.jforests.learning.LearningProgressListener;
import edu.uci.jforests.learning.boosting.GradientBoosting;
import edu.uci.jforests.learning.classification.GradientBoostingBinaryClassifier;
import edu.uci.jforests.learning.trees.Ensemble;
import edu.uci.jforests.learning.trees.decision.RandomForest;
import edu.uci.jforests.learning.trees.regression.RegressionTreeLearner;
import edu.uci.jforests.sample.Sample;
import edu.uci.jforests.util.ConfigHolder;
import edu.uci.jforests.util.Constants;
import edu.uci.jforests.util.IOUtils;
import edu.uci.jforests.util.Timer;
import edu.uci.jforests.util.concurrency.BlockingThreadPoolExecutor;
import java.io.InputStream;
import java.util.Properties;
import java.util.Random;

/**
 * Trains a model from a {@link Properties}-based configuration.
 *
 * <p>{@link #run(Properties)} validates the configuration, loads the training dataset (and an
 * optional validation dataset), draws the training/validation samples, builds the learner chain
 * named by the configured learning algorithm and returns the learned {@link Ensemble}.
 * Subclasses can customise the pipeline through the protected hooks such as
 * {@link #newDataset()}, {@link #initDataset(Dataset)}, {@link #createSample(Dataset, boolean)}
 * and {@link #getLearningModule(String)}.
 */
public class ClassificationApp {
    protected Dataset trainDataset;
    protected Dataset validDataset;
    protected LearningModule topLearner;
    protected Sample trainSet;
    protected Sample validSet;
    protected IOUtils ioUtils;
    protected EvaluationMetric evaluationMetric;
    protected Random rnd;
    protected TrainingConfig trainingConfig;
    protected ConfigHolder configHolder;
    protected LearningProgressListener progressListener = null;

    /** Creates an application with a default {@link IOUtils} instance. */
    public ClassificationApp() {
        this.initIOUtils();
    }

    /**
     * Sets the listener attached to every learning module created by {@link #createLearner()}.
     *
     * @param progressListener listener to attach, or {@code null} for none
     */
    public void setProgressListener(LearningProgressListener progressListener) {
        this.progressListener = progressListener;
    }

    /** Creates the {@link IOUtils} helper unless one has already been set. */
    protected void initIOUtils() {
        if (this.ioUtils == null) {
            this.ioUtils = new IOUtils();
        }
    }

    /** Builds {@link #trainingConfig} from the current {@link #configHolder}. */
    protected void loadConfig() {
        this.trainingConfig = new TrainingConfig();
        this.trainingConfig.init(this.configHolder);
    }

    /**
     * Initialises the shared thread pool and the loaded datasets, and optionally loads external
     * feature names into the training dataset.
     *
     * @throws Exception if dataset initialisation or reading the feature-names resource fails
     */
    protected void init() throws Exception {
        BlockingThreadPoolExecutor.init(this.trainingConfig.numThreads);
        this.initDataset(this.trainDataset);
        if (this.validSet != null) {
            this.initDataset(this.validSet.dataset);
        }
        if (this.trainingConfig.featureNamesFilename != null) {
            // Always release the stream, even if reading the feature names fails.
            try (InputStream featureNames =
                    this.ioUtils.getInputStream(this.trainingConfig.featureNamesFilename)) {
                this.trainDataset.loadFeatureNamesFromExternalResource(featureNames);
            }
        }
    }

    /**
     * Instantiates and initialises the learning module with the given name.
     *
     * @param name one of {@code GradientBoostingBinaryClassifier}, {@code GradientBoosting},
     *     {@code RegressionTree} or {@code RandomForest}
     * @return the initialised module
     * @throws Exception if {@code name} is unknown or module initialisation fails
     */
    protected LearningModule getLearningModule(String name) throws Exception {
        int maxNumTrainInstances = this.trainDataset.numInstances;
        // Without a separate validation dataset, the validation capacity falls back to the
        // training dataset size.
        int maxNumValidInstances = this.validDataset != null
                ? this.validDataset.numInstances
                : this.trainDataset.numInstances;
        if (name.equals("GradientBoostingBinaryClassifier")) {
            GradientBoostingBinaryClassifier learner = new GradientBoostingBinaryClassifier();
            learner.init(this.configHolder, maxNumTrainInstances, maxNumValidInstances, this.evaluationMetric);
            return learner;
        }
        if (name.equals("GradientBoosting")) {
            GradientBoosting learner = new GradientBoosting();
            learner.init(this.configHolder, maxNumTrainInstances, maxNumValidInstances, this.evaluationMetric);
            return learner;
        }
        if (name.equals("RegressionTree")) {
            RegressionTreeLearner learner = new RegressionTreeLearner();
            learner.init(this.trainDataset, this.configHolder, maxNumTrainInstances);
            return learner;
        }
        if (name.equals("RandomForest")) {
            RandomForest learner = new RandomForest();
            learner.init(this.trainDataset, this.configHolder, maxNumTrainInstances, maxNumValidInstances, this.evaluationMetric);
            return learner;
        }
        throw new Exception("Unknown algorithm: " + name);
    }

    /**
     * Instantiates the evaluation metric with the given name.
     *
     * @param name one of {@code AUC}, {@code RMSE}, {@code Accuracy} or
     *     {@code BalancedYoundenIndex}
     * @return a new metric instance
     * @throws Exception if {@code name} is unknown
     */
    protected EvaluationMetric getEvaluationMetric(String name) throws Exception {
        if (name.equals("AUC")) {
            return new AUC();
        }
        if (name.equals("RMSE")) {
            return new RMSE();
        }
        if (name.equals("Accuracy")) {
            return new Accuracy();
        }
        if (name.equals("BalancedYoundenIndex")) {
            return new BalancedYoundenIndex();
        }
        throw new Exception("Unknown evaluation metric: " + name);
    }

    /**
     * Builds the learner chain described by the configured learning algorithm.
     *
     * <p>The algorithm string is split on {@code '-'}. The first name becomes
     * {@link #topLearner}, and each following name becomes the sub-module of the one before it.
     * The progress listener, if any, is attached to every module.
     *
     * @throws Exception if any module name is unknown or fails to initialise
     */
    protected void createLearner() throws Exception {
        String[] parts = this.trainingConfig.learningAlgorithm.split("-");
        this.topLearner = this.getLearningModule(parts[0]);
        if (this.progressListener != null) {
            this.topLearner.setProgressListener(this.progressListener);
        }
        LearningModule curModule = this.topLearner;
        for (int i = 1; i < parts.length; ++i) {
            LearningModule newModule = this.getLearningModule(parts[i]);
            if (this.progressListener != null) {
                newModule.setProgressListener(this.progressListener);
            }
            curModule.setSubModule(newModule);
            curModule = newModule;
        }
    }

    /**
     * Reads {@code dataset} from an already-open stream. The stream is not closed here.
     *
     * @throws Exception if loading fails
     */
    protected void loadDataset(InputStream in, Dataset dataset) throws Exception {
        DatasetLoader.load(in, dataset);
    }

    /**
     * Loads {@code dataset} from {@code uri}, skipping the load when the dataset was already
     * loaded from the same URI.
     *
     * <p>{@code dataset.needsInitialization} is set to {@code false} when the cached content is
     * reused, and to {@code true} after a fresh load.
     *
     * @param uri location passed to {@link IOUtils#getInputStream}
     * @param dataset dataset to fill
     * @throws Exception if opening, reading or closing the resource fails
     */
    public void loadDataset(String uri, Dataset dataset) throws Exception {
        if (dataset != null && dataset.uri != null && dataset.uri.equals(uri)) {
            dataset.needsInitialization = false;
            return;
        }
        // try-with-resources: the stream is closed even if loading throws.
        try (InputStream in = this.ioUtils.getInputStream(uri)) {
            this.loadDataset(in, dataset);
            dataset.uri = uri;
            dataset.needsInitialization = true;
        }
    }

    /**
     * Evaluates {@code scores} on {@code sample} with the current evaluation metric.
     *
     * <p>NOTE(review): the meaning of the constant {@code 1.0} passed to
     * {@code Sample.evaluate} is not visible here — confirm against {@code Sample}.
     */
    protected double getMeasurement(double[] scores, Sample sample) throws Exception {
        return sample.evaluate(scores, this.evaluationMetric, 1.0);
    }

    /** Creates an empty dataset. Subclasses may return a specialised type. */
    protected Dataset newDataset() {
        return new Dataset();
    }

    /** Hook for dataset-specific initialisation. It does nothing in this class. */
    protected void initDataset(Dataset dataset) throws Exception {
    }

    /**
     * Wraps {@code dataset} in a {@link Sample}. This implementation ignores
     * {@code trainSample}; subclasses may use it.
     */
    protected Sample createSample(Dataset dataset, boolean trainSample) {
        return new Sample(dataset);
    }

    /**
     * Returns the training instance count. {@link #run(Properties)} uses it (together with the
     * validation size) to initialise {@link Constants}.
     */
    protected int getMaxTrainInstances() {
        return this.trainDataset.numInstances;
    }

    /**
     * Runs the full training pipeline for the given configuration.
     *
     * @param configProperties training configuration
     * @return the learned ensemble, or {@code null} if the configuration is invalid (an error is
     *     printed to standard output) or any step throws (the stack trace is printed)
     */
    public Ensemble run(Properties configProperties) {
        try {
            this.configHolder = new ConfigHolder(configProperties);
            this.loadConfig();
            if (!this.trainingConfig.validate(this.ioUtils)) {
                System.out.println("Error: " + this.trainingConfig.getErrorMessage());
                return null;
            }
            this.rnd = new Random(this.trainingConfig.randomSeed);
            System.out.println("Loading datasets...");
            if (this.trainDataset == null) {
                this.trainDataset = this.newDataset();
            }
            this.loadDataset(this.trainingConfig.trainFilename, this.trainDataset);
            int maxInstances = this.getMaxTrainInstances();
            if (this.trainingConfig.validFilename != null) {
                if (this.validDataset == null) {
                    this.validDataset = this.newDataset();
                }
                this.loadDataset(this.trainingConfig.validFilename, this.validDataset);
                if (this.validDataset.numInstances > maxInstances) {
                    maxInstances = this.validDataset.numInstances;
                }
            } else {
                this.validDataset = null;
            }
            System.out.println("Finished loading datasets.");
            Constants.init(maxInstances);
            Sample allTrainSample = this.createSample(this.trainDataset, true);
            this.trainSet = allTrainSample.getRandomSubSample(this.trainingConfig.trainFraction, this.rnd);
            // Clear any sample left from a previous run(), so a run without validation data
            // does not reuse a sample that references an old dataset.
            this.validSet = null;
            if (this.validDataset != null) {
                this.validSet = this.createSample(this.validDataset, false);
                if (this.trainingConfig.validFraction < 1.0) {
                    this.validSet = this.validSet.getRandomSubSample(this.trainingConfig.validFraction, this.rnd);
                }
            } else if (this.trainingConfig.validOutOfTrain) {
                this.validSet = allTrainSample.getOutOfSample(this.trainSet);
            }
            this.init();
            this.evaluationMetric = this.getEvaluationMetric(this.trainingConfig.evaluationMetric);
            this.createLearner();
            Timer timer = new Timer();
            timer.start();
            Ensemble ensemble = this.topLearner.learn(this.trainSet, this.validSet);
            System.out.println("Time taken to build model: " + (double)timer.getElapsedMillis() / 1000.0 + " seconds.");
            return ensemble;
        }
        catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /** Returns the configured random seed; requires a configuration to have been loaded. */
    public int getTrainingRandomSeed() {
        return this.trainingConfig.randomSeed;
    }

    /** Shuts down the shared {@link BlockingThreadPoolExecutor} if it exists and is still running. */
    public static void shutdown() {
        BlockingThreadPoolExecutor executor = BlockingThreadPoolExecutor.getInstance();
        if (executor != null && !executor.isShutdown()) {
            executor.shutdownNow();
        }
    }

    /** Returns the evaluation metric created by the last {@link #run(Properties)}, if any. */
    public EvaluationMetric getEvaluationMetric() {
        return this.evaluationMetric;
    }

    /**
     * Returns the top learner's validation measurement. It requires {@link #run(Properties)} to
     * have created a learner.
     */
    public double getValidMeasurement() throws Exception {
        return this.topLearner.getValidationMeasurement();
    }

    /** Returns the training sample drawn by the last {@link #run(Properties)}. */
    public Sample getTrainSample() {
        return this.trainSet;
    }

    /** Returns the validation sample of the last {@link #run(Properties)}, or {@code null}. */
    public Sample getValidSample() {
        return this.validSet;
    }

    /** Returns the configuration holder of the last {@link #run(Properties)}. */
    public ConfigHolder getConfigHolder() {
        return this.configHolder;
    }

    /** Returns the I/O helper used to open resources. */
    public IOUtils getIOUtils() {
        return this.ioUtils;
    }

    /** Returns the progress listener attached to created modules, or {@code null}. */
    public LearningProgressListener getProgressListener() {
        return this.progressListener;
    }
}

