/*
 * Decompiled with CFR 0.152.
 */
package dt;

import dt.Instance;
import dt.Node;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * ID3 decision-tree learner for binary classification over discrete features.
 *
 * <p>An instance counts as positive when {@code label == 1} and as negative otherwise. Every
 * feature value {@code fts[s]} is used directly as an index into tables of size
 * {@code Instance.FTSVALUERANGE}, so feature values must lie in
 * {@code [0, Instance.FTSVALUERANGE)}. Splits are chosen by impurity reduction (information gain
 * when using {@link #impurity_entropy}) and can be pre-pruned with a {@link ChiSquareTest}.
 */
public class ID3 {
    /** Smallest impurity reduction that still justifies splitting a node. */
    private static final double MIN_GAIN = 1.0E-10;

    /** Root of the tree built by the most recent {@code learn(...)} call; {@code null} before that. */
    Node root;

    /**
     * Binary entropy in bits, {@code -pa*log2(pa) - pb*log2(pb)}. An empty class contributes 0,
     * so {@code calc(0, 0)} is 0.
     */
    public static ImpurityFunction impurity_entropy = new ImpurityFunction() {

        @Override
        public double calc(int a, int b) {
            double total = (double) a + (double) b;
            double res = 0.0;
            if (a > 0) {
                double pa = (double) a / total;
                res -= pa * FastMath.log(pa);
            }
            if (b > 0) {
                double pb = (double) b / total;
                res -= pb * FastMath.log(pb);
            }
            // Natural log above; dividing by ln(2) converts nats to bits.
            return res / FastMath.log(2.0);
        }
    };

    /** Misclassification rate, {@code min(a, b) / (a + b)}; 0 for an empty subset. */
    public static ImpurityFunction impurity_misclassification = new ImpurityFunction() {

        @Override
        public double calc(int a, int b) {
            int total = a + b;
            if (total == 0) {
                // An empty subset has no impurity; without this guard 0/0 would yield NaN.
                return 0.0;
            }
            return (double) Math.min(a, b) / (double) total;
        }
    };

    // These thresholds equal the chi-square critical values for 3 degrees of freedom at
    // significance levels 0.001, 0.01 and 0.05.
    // NOTE(review): 3 dof matches a 4-valued feature against a binary label, i.e.
    // Instance.FTSVALUERANGE == 4. Confirm that.
    public static ChiSquareTest chi_square_001 = new ChiSquareTest(16.27);
    public static ChiSquareTest chi_square_01 = new ChiSquareTest(11.34);
    public static ChiSquareTest chi_square_05 = new ChiSquareTest(7.82);
    /** Threshold 0: accepts every split whose chi-square statistic exceeds 1e-8. */
    public static ChiSquareTest chi_square_100 = new ChiSquareTest(0.0);

    /**
     * Builds a tree without chi-square pruning (uses {@link #chi_square_100}).
     *
     * @param instances training instances
     * @param f impurity measure used to rank splits
     * @return the root of the grown tree
     */
    public static Node generate(List<Instance> instances, ImpurityFunction f) {
        return ID3.generate(instances, f, chi_square_100);
    }

    /**
     * Builds a tree, splitting only where {@code cst} accepts the best split.
     *
     * @param instances training instances
     * @param f impurity measure used to rank splits
     * @param cst pre-pruning test applied to the best split's value-by-label contingency table
     * @return the root of the grown tree
     */
    public static Node generate(List<Instance> instances, ImpurityFunction f, ChiSquareTest cst) {
        Node root = new Node(null, instances);
        ID3.expand(root, f, cst, 0);
        return root;
    }

    /**
     * Recursively splits {@code node} on the feature with the largest impurity reduction.
     *
     * <p>Nothing is assigned (the node stays a leaf) when the node has no instances, when the
     * best reduction is not above {@link #MIN_GAIN}, or when {@code cst} rejects the best split.
     * Otherwise {@code node.testFts} is set to the chosen feature. One child is then created, and
     * expanded, for each feature value that occurs among the node's instances.
     *
     * @param node node to expand; its {@code instances} are partitioned among new children
     * @param impurityFunction impurity measure, called as {@code calc(positives, negatives)}
     * @param cst pre-pruning test applied to the best split's contingency table
     * @param depth depth of {@code node} (root = 0); only passed down, not used for stopping
     */
    static void expand(Node node, ImpurityFunction impurityFunction, ChiSquareTest cst, int depth) {
        List<Instance> instances = node.instances;
        int total = instances.size();
        if (total == 0) {
            return;
        }

        int parentPos = 0;
        for (Instance t : instances) {
            if (t.label == 1) {
                ++parentPos;
            }
        }
        int parentNeg = total - parentPos;
        double parentImpurity = impurityFunction.calc(parentPos, parentNeg);

        double maxGain = Double.NEGATIVE_INFINITY;
        int maxGainDecision = -1;
        int[][] bestCount = null;
        for (int s = 0; s < node.numOfFts; ++s) {
            int[][] count = ID3.countByValue(instances, s);
            double gain = parentImpurity;
            for (int v = 0; v < Instance.FTSVALUERANGE; ++v) {
                int n = count[v][0] + count[v][1];
                if (n == 0) {
                    // Empty branches carry zero weight. Skipping them also stops an impurity
                    // function that returns NaN for (0, 0) from turning the whole gain into NaN.
                    continue;
                }
                gain -= (double) n / (double) total * impurityFunction.calc(count[v][1], count[v][0]);
            }
            // Strict '>' keeps the first feature among ties.
            if (gain > maxGain) {
                maxGain = gain;
                maxGainDecision = s;
                bestCount = count;
            }
        }

        // maxGain > MIN_GAIN implies a feature was chosen, so bestCount is non-null here.
        if (!(maxGain > MIN_GAIN) || !cst.test(bestCount)) {
            return;
        }

        node.testFts = maxGainDecision;
        List<List<Instance>> partitions = new ArrayList<>(Instance.FTSVALUERANGE);
        for (int v = 0; v < Instance.FTSVALUERANGE; ++v) {
            partitions.add(new ArrayList<Instance>());
        }
        for (Instance t : instances) {
            partitions.get(t.fts[maxGainDecision]).add(t);
        }
        for (int v = 0; v < Instance.FTSVALUERANGE; ++v) {
            List<Instance> subset = partitions.get(v);
            if (subset.isEmpty()) {
                continue;
            }
            node.children[v] = new Node(node, subset);
            ID3.expand(node.children[v], impurityFunction, cst, depth + 1);
        }
    }

    /**
     * Tallies how often each value of feature {@code s} occurs per class.
     *
     * @return table indexed {@code [value][0]} = negatives, {@code [value][1]} = positives
     */
    private static int[][] countByValue(List<Instance> instances, int s) {
        int[][] count = new int[Instance.FTSVALUERANGE][2];
        for (Instance t : instances) {
            ++count[t.fts[s]][t.label == 1 ? 1 : 0];
        }
        return count;
    }

    /** Trains on {@code instances} with impurity {@code f} and chi-square pre-pruning {@code cts}. */
    public void learn(ArrayList<Instance> instances, ImpurityFunction f, ChiSquareTest cts) {
        this.root = ID3.generate(instances, f, cts);
    }

    /** Trains on {@code instances} with impurity {@code f} and no chi-square pruning. */
    public void learn(ArrayList<Instance> instances, ImpurityFunction f) {
        this.root = ID3.generate(instances, f);
    }

    /** Trains on {@code instances} with entropy impurity and no chi-square pruning. */
    public void learn(List<Instance> instances) {
        this.root = ID3.generate(instances, impurity_entropy);
    }

    /**
     * Predicts a label for each instance using the tree from the last {@code learn(...)} call.
     *
     * @return predictions in the same order as {@code testInstances}
     */
    public List<Integer> classify(List<Instance> testInstances) {
        List<Integer> predictions = new ArrayList<>(testInstances.size());
        for (Instance t : testInstances) {
            predictions.add(this.root.classify(t));
        }
        return predictions;
    }

    /**
     * Reads one instance per line from both files, appending to the given lists.
     *
     * <p>IDs are assigned sequentially from 0, first across the training file and then
     * continuing through the test file. Errors are printed and reading is best-effort: lines
     * read before a failure are kept.
     */
    public static void load(String trainfile, String testfile, List<Instance> trainInstances, List<Instance> testInstances) {
        int nextId = ID3.readInstances(trainfile, trainInstances, 0);
        ID3.readInstances(testfile, testInstances, nextId);
    }

    /**
     * Appends one {@link Instance} per line of {@code file} to {@code out}, numbering from
     * {@code firstId}. The reader is always closed.
     *
     * @return the next unused ID (counts every line attempted, even one that failed to parse)
     */
    private static int readInstances(String file, List<Instance> out, int firstId) {
        int id = firstId;
        try (BufferedReader br = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = br.readLine()) != null) {
                out.add(new Instance(line, id++));
            }
        }
        catch (Exception e) {
            // Best-effort loading: report and keep whatever was read so far.
            e.printStackTrace();
        }
        return id;
    }

    /**
     * Fraction of predictions equal to the corresponding instance's label.
     *
     * <p>A {@code null} prediction counts as wrong.
     *
     * @return accuracy in [0, 1]; 0.0 when the lists differ in size or are empty
     */
    public static double computeAccuracy(List<Integer> predictions, List<Instance> testInstances) {
        if (predictions.size() != testInstances.size() || predictions.isEmpty()) {
            return 0.0;
        }
        int right = 0;
        for (int i = 0; i < predictions.size(); ++i) {
            Integer predicted = predictions.get(i);
            // Unbox explicitly so the comparison is numeric even if label is a boxed type.
            if (predicted != null && predicted.intValue() == testInstances.get(i).label) {
                ++right;
            }
        }
        return (double) right / (double) predictions.size();
    }

    /** Trains a full (unpruned) entropy tree and reports training and test accuracy. */
    public static void main(String[] args) {
        ArrayList<Instance> trainInstances = new ArrayList<Instance>();
        ArrayList<Instance> testInstances = new ArrayList<Instance>();
        ID3.load("training.txt", "test.txt", trainInstances, testInstances);
        ID3 id3 = new ID3();
        id3.learn(trainInstances);
        id3.root.writeTreeXML("tree_fulltree.xml");
        List<Integer> trainpredictions = id3.classify(trainInstances);
        System.out.println("ID3 with full tree on training\t" + ID3.computeAccuracy(trainpredictions, trainInstances));
        List<Integer> predictions = id3.classify(testInstances);
        System.out.println("ID3 with full tree on test\t" + ID3.computeAccuracy(predictions, testInstances));
    }

    /**
     * Chi-square test of independence between a feature's values (rows) and the class label
     * (columns) of a contingency table.
     */
    public static class ChiSquareTest {
        /** The statistic must exceed this value plus a 1e-8 tolerance for the test to pass. */
        double threshold;

        ChiSquareTest(double threshold) {
            this.threshold = threshold;
        }

        /**
         * Legacy variant that reads only rows 0 and 1 of the table and sums the chi-square
         * terms of row 0 alone. Superseded by {@link #test(int[][])}; kept for existing callers.
         */
        public boolean test_old(int[][] count) {
            double chi_square = 0.0;
            double num = count[0][0] + count[0][1] + count[1][0] + count[1][1];
            int j = 0;
            for (int i = 0; i < 2; ++i) {
                double ne = (double)(count[j][0] + count[j][1]) / num * (double)(count[0][i] + count[1][i]);
                double nl = count[j][i];
                chi_square += (nl - ne) * (nl - ne) / ne;
            }
            return chi_square > this.threshold + 1.0E-8;
        }

        /**
         * Pearson chi-square statistic over the whole table, {@code sum (O - E)^2 / E}, where
         * {@code E = rowTotal * colTotal / grandTotal}. Cells whose expected count is not
         * positive are skipped.
         *
         * @param count contingency table; at least one row, row width taken from {@code count[0]}
         * @return true when the statistic exceeds {@code threshold + 1e-8}
         */
        public boolean test(int[][] count) {
            int rows = count.length;
            int cols = count[0].length;
            double[] rowTotals = new double[rows];
            double[] colTotals = new double[cols];
            double grandTotal = 0.0;
            for (int r = 0; r < rows; ++r) {
                for (int c = 0; c < cols; ++c) {
                    rowTotals[r] += count[r][c];
                    colTotals[c] += count[r][c];
                    grandTotal += count[r][c];
                }
            }
            double chiSquare = 0.0;
            for (int r = 0; r < rows; ++r) {
                for (int c = 0; c < cols; ++c) {
                    double expected = rowTotals[r] * colTotals[c] / grandTotal;
                    if (!(expected > 0.0)) {
                        // Covers empty rows/columns (E == 0) and an all-zero table (NaN).
                        continue;
                    }
                    double observed = count[r][c];
                    chiSquare += (observed - expected) * (observed - expected) / expected;
                }
            }
            return chiSquare > this.threshold + 1.0E-8;
        }
    }

    /**
     * Impurity of a two-class subset, given the count of each class. {@code expand} calls it as
     * {@code calc(positives, negatives)}. It is also called with {@code (0, 0)} for an empty
     * parent, so implementations should return a finite value there.
     */
    public static abstract class ImpurityFunction {
        public abstract double calc(int a, int b);
    }
}

