/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.trees.ExtraTree;
import jsat.classifiers.trees.ImportanceByUses;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.MDI;
import jsat.classifiers.trees.TreeFeatureImportanceInference;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.exceptions.FailedToFitException;
import jsat.math.OnLineStatistics;
import jsat.regression.RegressionDataSet;
import jsat.utils.concurrent.ParallelUtils;

/**
 * An ensemble ("forest") of {@link ExtraTree} learners.
 *
 * <p>Every member of the forest is a clone of an internal base tree that is trained on the
 * complete data set handed to {@code train}. Classification is done by majority vote over the
 * members (one vote per tree for its most likely class); regression returns the arithmetic mean
 * of the members' predictions.
 *
 * <p>Unless disabled through {@link #setUseDefaultSelectionCount(boolean)} and
 * {@link #setUseDefaultStopSize(boolean)}, the base tree's selection count and stop size are
 * overwritten with problem-dependent defaults each time {@code train} is called.
 */
public class ERTrees extends ExtraTree {
    private static final long serialVersionUID = 7139392253403373132L;

    /** Template tree; each forest member is a clone of this instance. */
    private ExtraTree baseTree = new ExtraTree();
    /** When {@code true}, {@code train} overwrites the base tree's selection count. */
    private boolean useDefaultSelectionCount = true;
    /** When {@code true}, {@code train} overwrites the base tree's stop size. */
    private boolean useDefaultStopSize = true;
    /** The trained members, or {@code null} while no training has completed. */
    private ExtraTree[] forrest;
    /** Number of trees built by the next call to {@code train}; always at least 1. */
    private int forrestSize;

    /**
     * Creates an ensemble that will build 100 trees.
     */
    public ERTrees() {
        this(100);
    }

    /**
     * Creates an ensemble that will build the given number of trees.
     *
     * @param forrestSize the number of trees to train, must be at least 1
     * @throws IllegalArgumentException if {@code forrestSize} is less than 1
     */
    public ERTrees(int forrestSize) {
        this.forrestSize = requireValidForestSize(forrestSize);
    }

    /**
     * Copy constructor. The base tree and every trained member (if any) are cloned, so the new
     * instance does not share tree objects with {@code toCopy}.
     *
     * @param toCopy the ensemble to copy
     */
    public ERTrees(ERTrees toCopy) {
        super(toCopy);
        this.forrestSize = toCopy.forrestSize;
        this.useDefaultSelectionCount = toCopy.useDefaultSelectionCount;
        this.useDefaultStopSize = toCopy.useDefaultStopSize;
        this.baseTree = toCopy.baseTree.clone();
        if (toCopy.forrest != null) {
            this.forrest = new ExtraTree[toCopy.forrest.length];
            for (int i = 0; i < toCopy.forrest.length; ++i) {
                this.forrest[i] = toCopy.forrest[i].clone();
            }
        }
    }

    /**
     * Measures feature importance using a default inference method: {@link MDI} with the Gini
     * impurity measure when {@code data} is a {@link ClassificationDataSet}, and
     * {@link ImportanceByUses} for any other data set.
     *
     * @param <Type> the data set type
     * @param data the data set used to evaluate importance
     * @return one statistics object per feature, see
     *         {@link #evaluateFeatureImportance(DataSet, TreeFeatureImportanceInference)}
     * @throws IllegalStateException if the ensemble has not been trained
     */
    public <Type extends DataSet> OnLineStatistics[] evaluateFeatureImportance(DataSet<Type> data) {
        if (data instanceof ClassificationDataSet) {
            return this.evaluateFeatureImportance(data, new MDI(ImpurityScore.ImpurityMeasure.GINI));
        }
        return this.evaluateFeatureImportance(data, new ImportanceByUses());
    }

    /**
     * Measures feature importance by applying {@code imp} to every tree of the forest and
     * collecting the per-feature results.
     *
     * @param <Type> the data set type
     * @param data the data set used to evaluate importance
     * @param imp the importance inference applied to each tree
     * @return an array with one entry per feature of {@code data}; entry {@code i} has received
     *         one value per tree, namely that tree's importance score for feature {@code i}
     * @throws IllegalStateException if the ensemble has not been trained
     */
    public <Type extends DataSet> OnLineStatistics[] evaluateFeatureImportance(DataSet<Type> data, TreeFeatureImportanceInference imp) {
        ExtraTree[] trees = this.trainedForest();
        OnLineStatistics[] importances = new OnLineStatistics[data.getNumFeatures()];
        for (int i = 0; i < importances.length; ++i) {
            importances[i] = new OnLineStatistics();
        }
        for (ExtraTree tree : trees) {
            double[] feats = imp.getImportanceStats(tree, data);
            for (int i = 0; i < importances.length; ++i) {
                importances[i].add(feats[i]);
            }
        }
        return importances;
    }

    /**
     * Controls whether {@code train} overwrites the base tree's selection count with a default:
     * {@code max(round(sqrt(numFeatures)), 1)} for classification and {@code numFeatures} for
     * regression.
     *
     * @param useDefaultSelectionCount {@code true} to apply the default on every training call
     */
    public void setUseDefaultSelectionCount(boolean useDefaultSelectionCount) {
        this.useDefaultSelectionCount = useDefaultSelectionCount;
    }

    /**
     * @return {@code true} if the default selection count is applied when training
     */
    public boolean getUseDefaultSelectionCount() {
        return this.useDefaultSelectionCount;
    }

    /**
     * Controls whether {@code train} overwrites the base tree's stop size with a default:
     * 2 for classification and 5 for regression.
     *
     * @param useDefaultStopSize {@code true} to apply the default on every training call
     */
    public void setUseDefaultStopSize(boolean useDefaultStopSize) {
        this.useDefaultStopSize = useDefaultStopSize;
    }

    /**
     * @return {@code true} if the default stop size is applied when training
     */
    public boolean getUseDefaultStopSize() {
        return this.useDefaultStopSize;
    }

    /**
     * Sets the number of trees built by the next call to {@code train}. An already trained
     * forest is not modified.
     *
     * @param forrestSize the number of trees to train, must be at least 1
     * @throws IllegalArgumentException if {@code forrestSize} is less than 1
     */
    public void setForrestSize(int forrestSize) {
        this.forrestSize = requireValidForestSize(forrestSize);
    }

    /**
     * @return the number of trees built by the next call to {@code train}
     */
    public int getForrestSize() {
        return this.forrestSize;
    }

    /**
     * Classifies a data point by majority vote: every tree contributes a weight of 1.0 to the
     * class it considers most likely, and the tallies are normalized afterwards.
     *
     * @param data the data point to classify
     * @return the normalized vote distribution over the classes
     * @throws IllegalStateException if the ensemble has not been trained on a classification
     *         data set
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        ExtraTree[] trees = this.trainedForest();
        if (this.predicting == null) {
            throw new IllegalStateException("ERTrees has not been trained for classification");
        }
        CategoricalResults cr = new CategoricalResults(this.predicting.getNumOfCategories());
        for (ExtraTree tree : trees) {
            cr.incProb(tree.classify(data).mostLikely(), 1.0);
        }
        cr.normalize();
        return cr;
    }

    /**
     * Trains {@code forrestSize} clones of the base tree on {@code dataSet} and, once all of
     * them have been built, installs them as the current forest.
     *
     * @param parallel passed through to {@link ParallelUtils#run}
     * @param dataSet a {@link ClassificationDataSet} or {@link RegressionDataSet}
     * @throws RuntimeException if {@code dataSet} is neither of the two supported types
     */
    private void doTraining(boolean parallel, DataSet dataSet) throws FailedToFitException {
        // Decide the problem type once, on the calling thread, before any work is scheduled.
        final boolean classification = dataSet instanceof ClassificationDataSet;
        if (!classification && !(dataSet instanceof RegressionDataSet)) {
            throw new RuntimeException("BUG: Please report");
        }
        // Fill a local array and publish it only after the run has returned, so an exception
        // raised while training does not replace a previously trained forest with an array
        // that still contains null slots.
        final ExtraTree[] trained = new ExtraTree[this.forrestSize];
        ParallelUtils.run(parallel, trained.length, (start, end) -> {
            for (int i = start; i < end; ++i) {
                ExtraTree tree = this.baseTree.clone();
                if (classification) {
                    tree.train((ClassificationDataSet)dataSet);
                } else {
                    tree.train((RegressionDataSet)dataSet);
                }
                trained[i] = tree;
            }
        });
        this.forrest = trained;
    }

    /**
     * Trains the ensemble for classification. With the defaults enabled, the base tree's
     * selection count is set to {@code max(round(sqrt(numFeatures)), 1)} and its stop size to 2
     * before the trees are built.
     *
     * @param dataSet the training data
     * @param parallel passed through to the underlying parallel runner
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        if (this.useDefaultSelectionCount) {
            this.baseTree.setSelectionCount((int)Math.max(Math.round(Math.sqrt(dataSet.getNumFeatures())), 1L));
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(2);
        }
        this.doTraining(parallel, dataSet);
        // Recorded only after the forest was built, so the class information always belongs to
        // the forest that is actually installed.
        this.predicting = dataSet.getPredicting();
    }

    /**
     * @return always {@code true}
     */
    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /**
     * Predicts a numeric target as the arithmetic mean of the predictions of all trees.
     *
     * @param data the data point to predict for
     * @return the mean prediction of the forest
     * @throws IllegalStateException if the ensemble has not been trained
     */
    @Override
    public double regress(DataPoint data) {
        ExtraTree[] trees = this.trainedForest();
        double mean = 0.0;
        for (ExtraTree tree : trees) {
            mean += tree.regress(data);
        }
        return mean / (double)trees.length;
    }

    /**
     * Trains the ensemble for regression. With the defaults enabled, the base tree's selection
     * count is set to the number of features and its stop size to 5 before the trees are built.
     *
     * @param dataSet the training data
     * @param parallel passed through to the underlying parallel runner
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        if (this.useDefaultSelectionCount) {
            this.baseTree.setSelectionCount(dataSet.getNumFeatures());
        }
        if (this.useDefaultStopSize) {
            this.baseTree.setStopSize(5);
        }
        this.doTraining(parallel, dataSet);
    }

    /**
     * @return a copy of this ensemble created through the copy constructor
     */
    @Override
    public ERTrees clone() {
        return new ERTrees(this);
    }

    /**
     * Not supported, because this object is an ensemble rather than a single tree.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public TreeNodeVisitor getTreeNodeVisitor() {
        throw new UnsupportedOperationException("Can not get the tree node vistor becase ERTrees is really a ensemble");
    }

    /** Returns the trained forest, failing with a descriptive message when there is none. */
    private ExtraTree[] trainedForest() {
        if (this.forrest == null) {
            throw new IllegalStateException("ERTrees has not been trained yet");
        }
        return this.forrest;
    }

    /** Returns {@code forrestSize} unchanged if it is at least 1, otherwise rejects it. */
    private static int requireValidForestSize(int forrestSize) {
        if (forrestSize < 1) {
            throw new IllegalArgumentException("Forest size must be at least 1, not " + forrestSize);
        }
        return forrestSize;
    }
}

