/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.regression;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.utilities.MapMethods;
import com.datumbox.framework.core.machinelearning.MLBuilder;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractRegressor;
import com.datumbox.framework.core.machinelearning.common.dataobjects.TrainableBundle;
import com.datumbox.framework.core.machinelearning.common.interfaces.StepwiseCompatible;
import com.datumbox.framework.core.machinelearning.common.interfaces.Trainable;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;

/**
 * Stepwise regression using backward elimination.
 * <p>
 * Training repeatedly fits the regressor configured through
 * {@link TrainingParameters#getRegressionTrainingParameters()} on a copy of the training data.
 * After each fit it drops the feature with the largest p-value, as long as that p-value exceeds
 * {@link TrainingParameters#getAout()}. It stops when no feature qualifies, when no features are
 * left, or when {@link TrainingParameters#getMaxIterations()} iterations have run. A final
 * regressor is then trained on the reduced data and kept in an internal {@link TrainableBundle}.
 * Prediction is delegated to that regressor, and save/delete/close are applied to the bundle as
 * well as to this model.
 * <p>
 * The wrapped regressor must implement {@link StepwiseCompatible}. Otherwise training fails with
 * a {@link ClassCastException}.
 */
public class StepwiseRegression
extends AbstractRegressor<ModelParameters, TrainingParameters> {
    /** Bundle key (and storage-name suffix) under which the wrapped regressor is kept. */
    private static final String REG_KEY = "reg";

    /** P-value entry that is never considered for elimination (presumably the intercept). */
    private static final String CONSTANT_KEY = "~CONSTANT";

    /** Holds the wrapped regressor trained on the reduced feature set. */
    private final TrainableBundle bundle;

    /**
     * Creates a model configured with the given training parameters.
     *
     * @param trainingParameters the stepwise parameters, including the wrapped regressor's parameters
     * @param configuration the framework configuration
     */
    protected StepwiseRegression(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    /**
     * Creates a model associated with {@code storageName}. This is presumably used to load a
     * previously saved model; see the superclass constructor.
     *
     * @param storageName the storage name of the model
     * @param configuration the framework configuration
     */
    protected StepwiseRegression(String storageName, Configuration configuration) {
        super(storageName, configuration);
        this.bundle = new TrainableBundle(configuration.getStorageConfiguration().getStorageNameSeparator());
    }

    /**
     * Predicts {@code newData} with the wrapped regressor. The regressor is loaded from storage
     * first if it is not already in the bundle.
     */
    @Override
    protected void _predict(Dataframe newData) {
        initBundle();
        AbstractRegressor<?, ?> mlregressor = (AbstractRegressor<?, ?>) bundle.get(REG_KEY);
        mlregressor.predict(newData);
    }

    /**
     * Runs backward elimination on a copy of {@code trainingData}. It then trains the final
     * regressor on the surviving features and stores it in the bundle.
     */
    @Override
    protected void _fit(Dataframe trainingData) {
        TrainingParameters trainingParameters = (TrainingParameters) knowledgeBase.getTrainingParameters();
        Configuration configuration = knowledgeBase.getConfiguration();

        // Discard any regressor left over from a previous fit/load.
        resetBundle();

        // Features are dropped from this copy, not from the caller's trainingData.
        Dataframe copiedTrainingData = trainingData.copy();
        try {
            eliminateFeatures(copiedTrainingData, trainingParameters.getMaxIterations(), trainingParameters.getAout());

            AbstractRegressor<?, ?> mlregressor = (AbstractRegressor<?, ?>) MLBuilder.create(
                    trainingParameters.getRegressionTrainingParameters(), configuration);
            mlregressor.fit(copiedTrainingData);
            bundle.put(REG_KEY, mlregressor);
        } finally {
            // Release the copy even when a fit throws.
            copiedTrainingData.close();
        }
    }

    /**
     * Backward elimination, applied in place to {@code data}. Each iteration fits a temporary
     * regressor and drops the non-constant feature with the highest p-value, but only while that
     * p-value is greater than {@code aOut}.
     *
     * @param data the data whose X columns are pruned
     * @param maxIterations upper bound on the number of elimination iterations
     * @param aOut p-value threshold; features at or below it are kept
     */
    private void eliminateFeatures(Dataframe data, int maxIterations, double aOut) {
        for (int iteration = 0; iteration < maxIterations; ++iteration) {
            Map<Object, Double> pvalues = runRegression(data);
            if (pvalues.isEmpty()) {
                break; // no features reported
            }

            pvalues.remove(CONSTANT_KEY);
            if (pvalues.isEmpty()) {
                break; // only the constant term was left; nothing eligible for removal
            }

            Map.Entry<Object, Double> maxPvalueEntry = MapMethods.selectMaxKeyValue(pvalues);
            // NOTE(review): selectMaxKeyValue's behaviour on degenerate input is not visible here,
            // so a null result is treated as "nothing to remove".
            if (maxPvalueEntry == null || maxPvalueEntry.getValue() <= aOut) {
                break; // every remaining feature is significant at the aOut level
            }

            Set<Object> removedFeatures = new HashSet<>();
            removedFeatures.add(maxPvalueEntry.getKey());
            data.dropXColumns(removedFeatures);

            if (data.xColumnSize() == 0) {
                break; // no features left
            }
        }
    }

    /**
     * Saves this model under {@code storageName}. The wrapped regressor is saved under a
     * knowledge-base name derived from {@code storageName}.
     */
    @Override
    public void save(String storageName) {
        initBundle();
        super.save(storageName);
        String knowledgeBaseName = createKnowledgeBaseName(storageName, knowledgeBase.getConfiguration().getStorageConfiguration().getStorageNameSeparator());
        bundle.save(knowledgeBaseName);
    }

    /** Deletes the wrapped regressor (loading it first if needed) and then this model. */
    @Override
    public void delete() {
        initBundle();
        bundle.delete();
        super.delete();
    }

    /** Closes the wrapped regressor (loading it first if needed) and then this model. */
    @Override
    public void close() {
        initBundle();
        bundle.close();
        super.close();
    }

    /** Empties the bundle via {@link TrainableBundle#delete()} before a new fit. */
    private void resetBundle() {
        bundle.delete();
    }

    /**
     * Ensures the wrapped regressor is present in the bundle. If it is missing, it is loaded with
     * {@link MLBuilder#load} from {@code <storageName><separator>reg}.
     */
    private void initBundle() {
        if (bundle.containsKey(REG_KEY)) {
            return;
        }
        TrainingParameters trainingParameters = (TrainingParameters) knowledgeBase.getTrainingParameters();
        Configuration configuration = knowledgeBase.getConfiguration();
        String storageName = knowledgeBase.getStorageEngine().getStorageName();
        String separator = configuration.getStorageConfiguration().getStorageNameSeparator();

        AbstractTrainer.AbstractTrainingParameters mlParams = trainingParameters.getRegressionTrainingParameters();
        bundle.put(REG_KEY, (Trainable) MLBuilder.load(mlParams.getTClass(), storageName + separator + REG_KEY, configuration));
    }

    /**
     * Fits a temporary regressor on {@code trainingData} and returns its feature p-values.
     * <p>
     * The p-values are copied before the regressor is closed. The caller mutates the result, and
     * the original map belongs to a model that is closed here. The regressor is closed even if
     * fitting fails.
     *
     * @param trainingData the data to fit
     * @return a mutable, order-preserving copy of the feature p-values
     * @throws ClassCastException if the configured regressor is not {@link StepwiseCompatible}
     */
    private Map<Object, Double> runRegression(Dataframe trainingData) {
        AbstractRegressor<?, ?> mlregressor = (AbstractRegressor<?, ?>) MLBuilder.create(
                ((TrainingParameters) knowledgeBase.getTrainingParameters()).getRegressionTrainingParameters(),
                knowledgeBase.getConfiguration());
        try {
            mlregressor.fit(trainingData);
            return new LinkedHashMap<>(((StepwiseCompatible) mlregressor).getFeaturePvalues());
        } finally {
            mlregressor.close();
        }
    }

    /** Training parameters of {@link StepwiseRegression}. */
    public static class TrainingParameters
    extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** Maximum number of elimination iterations; defaults to {@link Integer#MAX_VALUE}. */
        private int maxIterations = Integer.MAX_VALUE;

        /** P-value threshold above which a feature is removed; defaults to 0.05. */
        private double aout = 0.05;

        /** Parameters of the wrapped regressor; it must implement {@link StepwiseCompatible}. */
        private AbstractTrainer.AbstractTrainingParameters regressionTrainingParameters;

        /** @return the maximum number of elimination iterations */
        public int getMaxIterations() {
            return this.maxIterations;
        }

        /** @param maxIterations the maximum number of elimination iterations */
        public void setMaxIterations(int maxIterations) {
            this.maxIterations = maxIterations;
        }

        /** @return the p-value threshold above which features are removed */
        public double getAout() {
            return this.aout;
        }

        /** @param aout the p-value threshold above which features are removed */
        public void setAout(double aout) {
            this.aout = aout;
        }

        /** @return the training parameters of the wrapped regressor */
        public AbstractTrainer.AbstractTrainingParameters getRegressionTrainingParameters() {
            return this.regressionTrainingParameters;
        }

        /** @param regressionTrainingParameters the training parameters of the wrapped regressor */
        public void setRegressionTrainingParameters(AbstractTrainer.AbstractTrainingParameters regressionTrainingParameters) {
            this.regressionTrainingParameters = regressionTrainingParameters;
        }
    }

    /** Model parameters of {@link StepwiseRegression}; the fitted state lives in the bundle. */
    public static class ModelParameters
    extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /** @param storageEngine the storage engine passed to the superclass */
        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }
}

