/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.Objects;
import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.evaluation.Accuracy;
import jsat.classifiers.evaluation.ClassificationScore;
import jsat.classifiers.trees.TreeFeatureImportanceInference;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.evaluation.MeanSquaredError;
import jsat.regression.evaluation.RegressionScore;
import jsat.utils.random.RandomUtil;

/**
 * Tree feature-importance estimator that scores each feature by how much a
 * performance metric degrades when the tree's decisions on that feature are
 * randomized (the name presumably stands for "mean decrease in accuracy").
 *
 * <p>For every feature {@code j}, each data point is pushed down the tree. At
 * any internal node whose {@code featuresUsed()} contains {@code j}, the
 * branch is chosen randomly instead of by the point's actual value. The
 * resulting score is compared against the score of the unmodified model on
 * the same data.
 *
 * <p>Classification data is scored with a {@link ClassificationScore}
 * ({@link Accuracy} by default). Regression data is scored with a
 * {@link RegressionScore} ({@link MeanSquaredError} by default).
 */
public class MDA
implements TreeFeatureImportanceInference {
    /**
     * Added to the magnitude of the baseline score when normalizing, so a
     * baseline of exactly 0 does not cause a division by zero.
     */
    private static final double EPSILON = 0.001;

    /** Template score for classification data; cloned for every call. */
    private final ClassificationScore cs_base;
    /** Template score for regression data; cloned for every call. */
    private final RegressionScore rs_base;

    /**
     * Creates an estimator that uses {@link Accuracy} for classification and
     * {@link MeanSquaredError} for regression.
     */
    public MDA() {
        this(new Accuracy(), new MeanSquaredError());
    }

    /**
     * Creates an estimator that uses the given scores.
     *
     * @param cs_base score used when the data is a {@link ClassificationDataSet};
     *                it is cloned before use and never modified
     * @param rs_base score used when the data is a {@link RegressionDataSet};
     *                it is cloned before use and never modified
     * @throws NullPointerException if either argument is {@code null}
     */
    public MDA(ClassificationScore cs_base, RegressionScore rs_base) {
        this.cs_base = Objects.requireNonNull(cs_base, "cs_base");
        this.rs_base = Objects.requireNonNull(rs_base, "rs_base");
    }

    /**
     * Computes one importance value per feature of {@code data}.
     *
     * <p>A value of {@code v} for feature {@code j} means the score got worse by
     * {@code v} relative to the baseline score (normalized by
     * {@code |baseline| + 0.001}) when feature {@code j}'s splits were
     * randomized. Larger values indicate a more important feature. Negative
     * values mean randomization happened to improve the score.
     *
     * @param <Type> the concrete data set type
     * @param model  the trained tree. It must also implement {@link Classifier}
     *               when {@code data} is a {@link ClassificationDataSet}, or
     *               {@link Regressor} when it is a {@link RegressionDataSet}.
     * @param data   the data used to evaluate the model (presumably held-out
     *               data supplied by the caller; this method does not check)
     * @return an array of length {@code data.getNumFeatures()}. It is all zeros
     *         if {@code data} is neither a classification nor a regression
     *         data set.
     * @throws ClassCastException if {@code model} does not implement the
     *                            interface matching the type of {@code data}
     */
    @Override
    public <Type extends DataSet> double[] getImportanceStats(TreeLearner model, DataSet<Type> data) {
        double[] features = new double[data.getNumFeatures()];
        Random rand = RandomUtil.getRandom();
        if (data instanceof ClassificationDataSet) {
            classificationImportance(model, (ClassificationDataSet) data, features, rand);
        } else if (data instanceof RegressionDataSet) {
            regressionImportance(model, (RegressionDataSet) data, features, rand);
        }
        return features;
    }

    /** Fills {@code features} with importances measured on classification data. */
    private void classificationImportance(TreeLearner model, ClassificationDataSet cds,
                                          double[] features, Random rand) {
        Classifier classifier = (Classifier) model;
        ClassificationScore cs = cs_base.clone();

        // Baseline: score of the unmodified model.
        cs.prepare(cds.getPredicting());
        for (int i = 0; i < cds.getSampleSize(); i++) {
            DataPoint dp = cds.getDataPoint(i);
            cs.addResult(classifier.classify(dp), cds.getDataPointCategory(i), dp.getWeight());
        }
        double baseScore = cs.getScore();
        boolean lowerIsBetter = cs.lowerIsBetter();

        for (int j = 0; j < features.length; j++) {
            cs.prepare(cds.getPredicting()); // reset the accumulator for feature j
            for (int i = 0; i < cds.getSampleSize(); i++) {
                DataPoint dp = cds.getDataPoint(i);
                int trueLabel = cds.getDataPointCategory(i);
                TreeNodeVisitor node = walkCorruptedPath(model, dp, j, rand);
                cs.addResult(node.localClassify(dp), trueLabel, dp.getWeight());
            }
            features[j] = relativeDegradation(baseScore, cs.getScore(), lowerIsBetter);
        }
    }

    /** Fills {@code features} with importances measured on regression data. */
    private void regressionImportance(TreeLearner model, RegressionDataSet rds,
                                      double[] features, Random rand) {
        Regressor regressor = (Regressor) model;
        RegressionScore rs = rs_base.clone();

        // Baseline: score of the unmodified model.
        rs.prepare();
        for (int i = 0; i < rds.getSampleSize(); i++) {
            DataPoint dp = rds.getDataPoint(i);
            rs.addResult(regressor.regress(dp), rds.getTargetValue(i), dp.getWeight());
        }
        double baseScore = rs.getScore();
        boolean lowerIsBetter = rs.lowerIsBetter();

        for (int j = 0; j < features.length; j++) {
            rs.prepare(); // reset the accumulator for feature j
            for (int i = 0; i < rds.getSampleSize(); i++) {
                DataPoint dp = rds.getDataPoint(i);
                double trueValue = rds.getTargetValue(i);
                TreeNodeVisitor node = walkCorruptedPath(model, dp, j, rand);
                rs.addResult(node.localRegress(dp), trueValue, dp.getWeight());
            }
            features[j] = relativeDegradation(baseScore, rs.getScore(), lowerIsBetter);
        }
    }

    /**
     * Returns how much worse {@code newScore} is than {@code baseScore},
     * normalized by the magnitude of the baseline.
     *
     * <p>{@code Math.abs} keeps the sign meaningful for metrics that can be
     * negative. For the non-negative default metrics it equals the plain
     * {@code baseScore + EPSILON} denominator.
     */
    private static double relativeDegradation(double baseScore, double newScore, boolean lowerIsBetter) {
        double worsening = lowerIsBetter ? newScore - baseScore : baseScore - newScore;
        return worsening / (Math.abs(baseScore) + EPSILON);
    }

    /**
     * Walks {@code dp} from the root of {@code model}'s tree and returns the
     * node where the walk stops. At every internal node that uses feature
     * {@code j}, the branch is chosen randomly.
     *
     * <p>The walk stops at a leaf. It also stops early at an internal node
     * whose selected child path is disabled; that internal node is returned.
     *
     * @param model the tree to walk
     * @param dp    the data point to route
     * @param j     index of the feature whose splits are randomized
     * @param rand  source of randomness
     * @return the node at which the walk ended
     */
    private TreeNodeVisitor walkCorruptedPath(TreeLearner model, DataPoint dp, int j, Random rand) {
        TreeNodeVisitor curNode = model.getTreeNodeVisitor();
        while (!curNode.isLeaf()) {
            int path = curNode.getPath(dp);
            int numChild = curNode.childrenCount();
            if (curNode.featuresUsed().contains(j)) {
                // Shift by a random offset in [0, numChild). An offset of 0 keeps the
                // original branch, so the "corrupted" path equals the true path with
                // probability 1/numChild.
                path = (path + rand.nextInt(numChild)) % numChild;
            }
            if (curNode.isPathDisabled(path)) {
                break;
            }
            curNode = curNode.getChild(path);
        }
        return curNode;
    }
}

