/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.trees;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.classifiers.trees.ImpurityScore;
import jsat.classifiers.trees.TreeFeatureImportanceInference;
import jsat.classifiers.trees.TreeLearner;
import jsat.classifiers.trees.TreeNodeVisitor;

public class MDI
implements TreeFeatureImportanceInference {
    private ImpurityScore.ImpurityMeasure im;

    public MDI(ImpurityScore.ImpurityMeasure im) {
        this.im = im;
    }

    public MDI() {
        this(ImpurityScore.ImpurityMeasure.GINI);
    }

    @Override
    public <Type extends DataSet> double[] getImportanceStats(TreeLearner model, DataSet<Type> data) {
        double[] features = new double[data.getNumFeatures()];
        if (!(data instanceof ClassificationDataSet)) {
            throw new RuntimeException("MDI currently only supports classification datasets");
        }
        List<DataPointPair<Integer>> allData = ((ClassificationDataSet)data).getAsDPPList();
        int K = ((ClassificationDataSet)data).getClassSize();
        ImpurityScore score = new ImpurityScore(K, this.im);
        for (DataPointPair<Integer> d : allData) {
            score.addPoint(d.getDataPoint(), (int)d.getPair());
        }
        this.visit(model.getTreeNodeVisitor(), score, allData, features, score.getSumOfWeights(), K);
        return features;
    }

    /*
     * WARNING - void declaration
     */
    private void visit(TreeNodeVisitor node, ImpurityScore score, List<DataPointPair<Integer>> data, double[] features, double N, int K) {
        void var17_22;
        if (node == null || node.isLeaf()) {
            return;
        }
        double curScore = score.getScore();
        double curN = score.getSumOfWeights();
        ArrayList<List<DataPointPair<Integer>>> splitsData = new ArrayList<List<DataPointPair<Integer>>>(node.childrenCount());
        ArrayList<ImpurityScore> splitScores = new ArrayList<ImpurityScore>(node.childrenCount());
        splitsData.add(data);
        splitScores.add(score);
        for (int i = 0; i < node.childrenCount() - 1; ++i) {
            splitsData.add(new ArrayList());
            splitScores.add(new ImpurityScore(K, this.im));
        }
        ListIterator<DataPointPair<Integer>> iter = data.listIterator();
        while (iter.hasNext()) {
            DataPointPair<Integer> curPoint = iter.next();
            int tc = curPoint.getPair();
            DataPoint dataPoint = curPoint.getDataPoint();
            int path = node.getPath(dataPoint);
            if (path < 0) {
                score.removePoint(dataPoint, tc);
                continue;
            }
            if (path <= 0) continue;
            score.removePoint(dataPoint, tc);
            ((ImpurityScore)splitScores.get(path)).addPoint(dataPoint, tc);
            ((List)splitsData.get(path)).add(curPoint);
            iter.remove();
        }
        double chageInImp = curScore;
        for (ImpurityScore impurityScore : splitScores) {
            chageInImp -= impurityScore.getScore() * (impurityScore.getSumOfWeights() / (1.0E-5 + curN));
        }
        Collection<Integer> featuresUsed = node.featuresUsed();
        Iterator<Integer> iterator = featuresUsed.iterator();
        while (iterator.hasNext()) {
            int feature;
            int n = feature = iterator.next().intValue();
            features[n] = features[n] + chageInImp * curN / N;
        }
        boolean bl = false;
        while (var17_22 < splitScores.size()) {
            this.visit(node.getChild((int)var17_22), (ImpurityScore)splitScores.get((int)var17_22), (List)splitsData.get((int)var17_22), features, N, K);
            ++var17_22;
        }
    }
}

