/*
 * Decompiled with CFR 0.152.
 */
package jsat.datatransform.featureselection;

import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.datatransform.featureselection.SBS;
import jsat.datatransform.featureselection.SFS;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;
import jsat.utils.random.RandomUtil;

/**
 * Feature selection by bidirectional search (BDS).
 * <p>
 * Each iteration of the search runs one step of {@link SFS}, which picks a feature that must
 * be kept, and one step of {@link SBS}, which picks a feature that must be removed. A feature
 * chosen by either side is made unavailable to the other side. When the search ends, every
 * feature that the forward (SFS) side did not select is dropped by a
 * {@link RemoveAttributeTransform}.
 * <p>
 * The search runs for {@code min(featureCount, numFeatures / 2)} iterations. Feature subsets
 * are scored with the configured evaluator (a {@link Classifier} or {@link Regressor}) using
 * the configured number of cross-validation folds.
 */
public class BDS
implements DataTransform {
    private static final long serialVersionUID = 8633823674617843754L;
    /** Removes every non-selected feature; {@code null} until the transform has been fitted. */
    private RemoveAttributeTransform finalTransform;
    /** Indices of the selected categorical features; {@code null} until fitted. */
    private Set<Integer> catSelected;
    /** Indices of the selected numerical features; {@code null} until fitted. */
    private Set<Integer> numSelected;
    /** Upper bound on the number of search iterations (see {@link #search}). */
    private int featureCount;
    /** Number of cross-validation folds handed to the SFS/SBS steps. */
    private int folds;
    /** Model used to score feature subsets: a {@link Classifier} or a {@link Regressor}. */
    private Object evaluator;

    /**
     * Copy constructor. The fitted state (if any) is deep-copied; the evaluator reference is
     * shared with {@code toClone}, not copied.
     *
     * @param toClone the instance to copy
     */
    public BDS(BDS toClone) {
        this.featureCount = toClone.featureCount;
        this.folds = toClone.folds;
        this.evaluator = toClone.evaluator;
        if (toClone.finalTransform != null) {
            this.finalTransform = toClone.finalTransform.clone();
            this.catSelected = new IntSet(toClone.catSelected);
            this.numSelected = new IntSet(toClone.numSelected);
        }
    }

    /**
     * Creates an unfitted BDS transform that scores subsets with a classifier.
     *
     * @param featureCount maximum number of features to select, must be at least 1
     * @param evaluator the classifier used to score feature subsets
     * @param folds number of cross-validation folds, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1} or {@code folds <= 0}
     */
    public BDS(int featureCount, Classifier evaluator, int folds) {
        this.setFeatureCount(featureCount);
        this.setFolds(folds);
        this.setEvaluator(evaluator);
    }

    /**
     * Creates a BDS transform and immediately fits it on a classification data set.
     *
     * @param featureCount maximum number of features to select, must be at least 1
     * @param dataSet the data to select features from
     * @param evaluator the classifier used to score feature subsets
     * @param folds number of cross-validation folds, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1} or {@code folds <= 0}
     */
    public BDS(int featureCount, ClassificationDataSet dataSet, Classifier evaluator, int folds) {
        // Store (and validate) the configuration, exactly like the regression constructor does,
        // so that clone() and a later fit() use the same settings as this initial search.
        this(featureCount, evaluator, folds);
        this.search(dataSet, featureCount, folds, evaluator);
    }

    /**
     * Creates an unfitted BDS transform that scores subsets with a regressor.
     *
     * @param featureCount maximum number of features to select, must be at least 1
     * @param evaluator the regressor used to score feature subsets
     * @param folds number of cross-validation folds, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1} or {@code folds <= 0}
     */
    public BDS(int featureCount, Regressor evaluator, int folds) {
        this.setFeatureCount(featureCount);
        this.setFolds(folds);
        this.setEvaluator(evaluator);
    }

    /**
     * Creates a BDS transform and immediately fits it on a regression data set.
     *
     * @param featureCount maximum number of features to select, must be at least 1
     * @param dataSet the data to select features from
     * @param evaluator the regressor used to score feature subsets
     * @param folds number of cross-validation folds, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1} or {@code folds <= 0}
     */
    public BDS(int featureCount, RegressionDataSet dataSet, Regressor evaluator, int folds) {
        this(featureCount, evaluator, folds);
        this.search(dataSet, featureCount, folds, evaluator);
    }

    /**
     * Removes the non-selected features from a data point.
     *
     * @param dp the data point to transform
     * @return the transformed data point
     * @throws IllegalStateException if the transform has not been fitted
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        this.checkFitted();
        return this.finalTransform.transform(dp);
    }

    @Override
    public BDS clone() {
        return new BDS(this);
    }

    /**
     * Returns a copy of the indices of the categorical features that were selected.
     *
     * @return a new set of selected categorical feature indices
     * @throws IllegalStateException if the transform has not been fitted
     */
    public Set<Integer> getSelectedCategorical() {
        this.checkFitted();
        return new IntSet(this.catSelected);
    }

    /**
     * Returns a copy of the indices of the numerical features that were selected.
     *
     * @return a new set of selected numerical feature indices
     * @throws IllegalStateException if the transform has not been fitted
     */
    public Set<Integer> getSelectedNumerical() {
        this.checkFitted();
        return new IntSet(this.numSelected);
    }

    /**
     * Runs the bidirectional search on {@code data} using the configured feature count,
     * folds and evaluator, replacing any previous fit.
     *
     * @param data the data to select features from
     */
    @Override
    public void fit(DataSet data) {
        this.search(data, this.featureCount, this.folds, this.evaluator);
    }

    /** Throws {@link IllegalStateException} if {@link #search} has not produced a transform yet. */
    private void checkFitted() {
        if (this.finalTransform == null) {
            throw new IllegalStateException("BDS has not been fit to a data set yet");
        }
    }

    /** Returns a new set containing the indices {@code 0, 1, ..., n - 1}. */
    private static IntSet allIndices(int n) {
        IntSet indices = new IntSet(n);
        ListUtils.addRange(indices, 0, n, 1);
        return indices;
    }

    /**
     * Performs the bidirectional search and builds {@link #finalTransform}.
     *
     * @param dataSet the data to select features from
     * @param maxFeatures upper bound on iterations; the search runs
     *        {@code min(maxFeatures, numFeatures / 2)} times
     * @param folds number of cross-validation folds passed to the SFS/SBS steps
     * @param evaluator the {@link Classifier} or {@link Regressor} used for scoring
     */
    private void search(DataSet dataSet, int maxFeatures, int folds, Object evaluator) {
        Random rand = RandomUtil.getRandom();
        int nF = dataSet.getNumFeatures();
        int nCat = dataSet.getNumCategoricalVars();
        int nNum = nF - nCat;

        // Forward (SFS) side. this.catSelected / this.numSelected start empty and are handed to
        // SFS as its "selected" sets; they become the final selection.
        this.catSelected = new IntSet(nCat);
        this.numSelected = new IntSet(dataSet.getNumNumericalVars());
        IntSet availableSFS = allIndices(nF);
        IntSet catToRemoveSFS = allIndices(nCat);
        IntSet numToRemoveSFS = allIndices(nNum);

        // Backward (SBS) side: starts with every feature selected.
        IntSet availableSBS = allIndices(nF);
        IntSet catSelectedSBS = allIndices(nCat);
        IntSet numSelectedSBS = allIndices(nNum);
        IntSet catToRemoveSBS = new IntSet(nCat);
        IntSet numToRemoveSBS = new IntSet(dataSet.getNumNumericalVars());

        // Single-element arrays act as mutable best-score holders shared with SFS/SBS.
        double[] pBestScore0 = new double[]{Double.POSITIVE_INFINITY};
        double[] pBestScore1 = new double[]{Double.POSITIVE_INFINITY};

        int max = Math.min(maxFeatures, nF / 2);
        for (int i = 0; i < max; ++i) {
            // Forward step: the chosen feature must be kept, so SBS may no longer remove it.
            int mustKeep = SFS.SFSSelectFeature(availableSFS, dataSet, catToRemoveSFS, numToRemoveSFS, this.catSelected, this.numSelected, evaluator, folds, rand, pBestScore0, max);
            availableSBS.remove(Integer.valueOf(mustKeep));
            SFS.removeFeature(mustKeep, nCat, catToRemoveSBS, numToRemoveSBS);

            // Backward step: the chosen feature must be removed, so SFS may no longer pick it.
            int mustRemove = SBS.SBSRemoveFeature(availableSBS, dataSet, catToRemoveSBS, numToRemoveSBS, catSelectedSBS, numSelectedSBS, evaluator, folds, rand, max, pBestScore1, 0.0);
            availableSFS.remove(Integer.valueOf(mustRemove));
            SFS.addFeature(mustRemove, nCat, catToRemoveSFS, numToRemoveSFS);
        }

        // Every feature the forward side did not select is removed by the final transform.
        IntSet catToRemove = allIndices(nCat);
        catToRemove.removeAll(this.catSelected);
        IntSet numToRemove = allIndices(nNum);
        numToRemove.removeAll(this.numSelected);
        this.finalTransform = new RemoveAttributeTransform(dataSet, catToRemove, numToRemove);
    }

    /**
     * Sets the maximum number of features to select. The search additionally never runs more
     * than {@code numFeatures / 2} iterations.
     *
     * @param featureCount the maximum number of features, must be at least 1
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public void setFeatureCount(int featureCount) {
        if (featureCount < 1) {
            throw new IllegalArgumentException("Number of features to select must be positive, not " + featureCount);
        }
        this.featureCount = featureCount;
    }

    /**
     * Returns the configured maximum number of features to select.
     *
     * @return the feature count
     */
    public int getFeatureCount() {
        return this.featureCount;
    }

    /**
     * Sets the number of cross-validation folds used when scoring feature subsets.
     *
     * @param folds the number of folds, must be positive
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /**
     * Returns the number of cross-validation folds.
     *
     * @return the number of folds
     */
    public int getFolds() {
        return this.folds;
    }

    /** Stores the evaluator; callers only ever pass a {@link Classifier} or {@link Regressor}. */
    private void setEvaluator(Object evaluator) {
        this.evaluator = evaluator;
    }
}

