package dt;

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * ID3 decision-tree learner for binary classification.
 *
 * <p>Instances fall into two classes: {@code label == 1} and everything else. Each feature value
 * {@code instance.fts[f]} is used directly as an array index, so values are expected to lie in
 * {@code [0, Instance.FTSVALUERANGE)}. Tree growth can optionally be pre-pruned with a chi-square
 * independence test ({@link ChiSquareTest}).
 */
public class ID3 {
    /** Smallest impurity reduction that still justifies splitting a node. */
    private static final double MIN_GAIN = 1.0E-10d;

    /** Root of the learned tree; {@code null} until one of the {@code learn} methods has run. */
    Node root;

    /**
     * Two-class entropy in bits. Returns 0 when both counts are 0 (an empty bucket has no
     * impurity).
     */
    public static ImpurityFunction impurity_entropy = new ImpurityFunction() {
        @Override
        public double calc(int a, int b) {
            int total = a + b;
            if (total == 0) {
                return 0.0d;
            }
            // Explicit double division: with int operands every proportion would truncate to 0 or 1.
            double pa = (double) a / total;
            double pb = (double) b / total;
            double entropy = 0.0d;
            if (a > 0) {
                entropy -= pa * FastMath.log(pa);
            }
            if (b > 0) {
                entropy -= pb * FastMath.log(pb);
            }
            return entropy / FastMath.log(2.0d);
        }
    };

    /**
     * Misclassification rate of a majority vote: {@code min(a, b) / (a + b)}. Returns 0 when both
     * counts are 0.
     */
    public static ImpurityFunction impurity_misclassification = new ImpurityFunction() {
        @Override
        public double calc(int a, int b) {
            int total = a + b;
            return total == 0 ? 0.0d : (double) Math.min(a, b) / total;
        }
    };

    // 16.27 / 11.34 / 7.82 match the chi-square critical values for 3 degrees of freedom at
    // p = 0.001 / 0.01 / 0.05. That presumably assumes a 4-row value x label table
    // (Instance.FTSVALUERANGE == 4?) -- TODO confirm.
    public static ChiSquareTest chi_square_001 = new ChiSquareTest(16.27d);
    public static ChiSquareTest chi_square_01 = new ChiSquareTest(11.34d);
    public static ChiSquareTest chi_square_05 = new ChiSquareTest(7.82d);
    /** Threshold 0: only requires a non-zero statistic, i.e. effectively no pre-pruning. */
    public static ChiSquareTest chi_square_100 = new ChiSquareTest(0.0d);

    /** Pearson chi-square independence test used to decide whether a split is significant. */
    public static class ChiSquareTest {
        /** The statistic must exceed this value (plus a small epsilon) for the test to pass. */
        double threshold;

        ChiSquareTest(double threshold) {
            this.threshold = threshold;
        }

        /**
         * Legacy variant for a 2x2 table. It sums {@code (O - E)^2 / E} over the two cells of row 0
         * only.
         *
         * @param iArr 2x2 table of observed counts
         * @return true if the partial statistic exceeds {@code threshold + 1e-8}
         */
        public boolean test_old(int[][] iArr) {
            double total = iArr[0][0] + iArr[0][1] + iArr[1][0] + iArr[1][1];
            double statistic = 0.0d;
            for (int c = 0; c < 2; c++) {
                double expected = ((iArr[0][0] + iArr[0][1]) / total) * (iArr[0][c] + iArr[1][c]);
                double observed = iArr[0][c];
                statistic += ((observed - expected) * (observed - expected)) / expected;
            }
            return statistic > this.threshold + 1.0E-8d;
        }

        /**
         * Computes Pearson's chi-square statistic for a rectangular contingency table.
         * Cells whose expected count is 0 (empty row or column) are skipped.
         *
         * @param iArr observed counts, {@code iArr[row][column]}
         * @return true if the statistic exceeds {@code threshold + 1e-8}
         */
        public boolean test(int[][] iArr) {
            int rows = iArr.length;
            int cols = iArr[0].length;
            double[] rowTotals = new double[rows];
            double[] colTotals = new double[cols];
            double grandTotal = 0.0d;
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    rowTotals[r] += iArr[r][c];
                    colTotals[c] += iArr[r][c];
                    grandTotal += iArr[r][c];
                }
            }
            double statistic = 0.0d;
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    double expected = ((1.0d * rowTotals[r]) * colTotals[c]) / grandTotal;
                    double observed = iArr[r][c];
                    if (expected > 0.0d) {
                        statistic += ((expected - observed) * (expected - observed)) / expected;
                    }
                }
            }
            // Epsilon keeps a threshold of 0 from passing on rounding noise alone.
            return statistic > this.threshold + 1.0E-8d;
        }
    }

    /** Impurity measure over the counts of the two classes in a set of instances. */
    public static abstract class ImpurityFunction {
        /**
         * @param i count of one class
         * @param i2 count of the other class
         * @return impurity of a set holding those counts
         */
        public abstract double calc(int i, int i2);
    }

    /** Grows a full tree (no chi-square pre-pruning) over {@code list}. */
    public static Node generate(List<Instance> list, ImpurityFunction impurityFunction) {
        return generate(list, impurityFunction, chi_square_100);
    }

    /** Grows a tree over {@code list}, only splitting where {@code chiSquareTest} passes. */
    public static Node generate(List<Instance> list, ImpurityFunction impurityFunction, ChiSquareTest chiSquareTest) {
        Node node = new Node(null, list);
        expand(node, impurityFunction, chiSquareTest, 0);
        return node;
    }

    /**
     * Recursively splits {@code node} on the feature with the largest impurity reduction.
     *
     * <p>A node is left as a leaf when it is empty, when no feature reduces impurity by more than
     * {@link #MIN_GAIN}, or when the best feature's value x label table fails {@code chiSquareTest}.
     * Ties between features go to the lowest feature index.
     *
     * @param depth depth of {@code node} in the tree (root = 0)
     */
    static void expand(Node node, ImpurityFunction impurityFunction, ChiSquareTest chiSquareTest, int depth) {
        if (node.instances.isEmpty()) {
            return;
        }
        int negatives = 0;
        int positives = 0;
        for (Instance instance : node.instances) {
            if (instance.label == 1) {
                positives++;
            } else {
                negatives++;
            }
        }
        int total = negatives + positives;
        // Same (negatives, positives) argument order as the per-value rows from labelCountsByValue,
        // so an asymmetric ImpurityFunction is applied consistently to parent and children.
        double parentImpurity = impurityFunction.calc(negatives, positives);

        double bestGain = Double.NEGATIVE_INFINITY;
        int bestFeature = -1;
        int[][] bestTable = null;
        for (int feature = 0; feature < node.numOfFts; feature++) {
            int[][] table = labelCountsByValue(node, feature);
            double gain = informationGain(impurityFunction, parentImpurity, table, total);
            if (gain > bestGain) {
                bestGain = gain;
                bestFeature = feature;
                bestTable = table;
            }
        }
        if (bestFeature < 0 || bestGain <= MIN_GAIN || !chiSquareTest.test(bestTable)) {
            return;
        }

        node.testFts = bestFeature;
        List<List<Instance>> partitions = new ArrayList<>(Instance.FTSVALUERANGE);
        for (int v = 0; v < Instance.FTSVALUERANGE; v++) {
            partitions.add(new ArrayList<Instance>());
        }
        for (Instance instance : node.instances) {
            partitions.get(instance.fts[bestFeature]).add(instance);
        }
        for (int v = 0; v < Instance.FTSVALUERANGE; v++) {
            List<Instance> subset = partitions.get(v);
            if (!subset.isEmpty()) {
                node.children[v] = new Node(node, subset);
                expand(node.children[v], impurityFunction, chiSquareTest, depth + 1);
            }
        }
    }

    /**
     * Builds the value x label table for {@code feature}: row {@code v} is
     * {@code [count with label != 1, count with label == 1]} among instances whose feature value is
     * {@code v}.
     */
    private static int[][] labelCountsByValue(Node node, int feature) {
        int[][] table = new int[Instance.FTSVALUERANGE][2];
        for (Instance instance : node.instances) {
            table[instance.fts[feature]][instance.label == 1 ? 1 : 0]++;
        }
        return table;
    }

    /** Returns the parent impurity minus the size-weighted impurity of every non-empty value bucket. */
    private static double informationGain(ImpurityFunction impurityFunction, double parentImpurity, int[][] table, int total) {
        double gain = parentImpurity;
        for (int[] row : table) {
            int bucketSize = row[0] + row[1];
            if (bucketSize == 0) {
                // Zero weight; skipping also avoids 0 * NaN if calc(0, 0) is undefined.
                continue;
            }
            gain -= ((double) bucketSize / total) * impurityFunction.calc(row[0], row[1]);
        }
        return gain;
    }

    public void learn(ArrayList<Instance> arrayList, ImpurityFunction impurityFunction, ChiSquareTest chiSquareTest) {
        this.root = generate(arrayList, impurityFunction, chiSquareTest);
    }

    public void learn(ArrayList<Instance> arrayList, ImpurityFunction impurityFunction) {
        this.root = generate(arrayList, impurityFunction);
    }

    /** Learns a full (unpruned) tree using entropy as the impurity measure. */
    public void learn(List<Instance> list) {
        this.root = generate(list, impurity_entropy);
    }

    /**
     * Classifies every instance with the learned tree. Requires one of the {@code learn} methods to
     * have run first.
     *
     * @return predictions in the same order as {@code list}
     */
    public List<Integer> classify(List<Instance> list) {
        List<Integer> predictions = new ArrayList<>(list.size());
        for (Instance instance : list) {
            predictions.add(Integer.valueOf(this.root.classify(instance)));
        }
        return predictions;
    }

    /**
     * Parses one {@link Instance} per line from a training file and a test file.
     *
     * <p>Ids are assigned sequentially starting at 0 across both files: test ids continue where the
     * training ids stopped. Loading is best-effort: an error in either file is printed, and the
     * instances read before it are kept.
     */
    public static void load(String trainingPath, String testPath, List<Instance> training, List<Instance> test) {
        int nextId = readInstances(trainingPath, training, 0);
        readInstances(testPath, test, nextId);
    }

    /**
     * Appends one {@link Instance} per line of {@code path} to {@code out}, numbering them from
     * {@code firstId}.
     *
     * @return the next unused id, even if reading stopped early because of an error
     */
    private static int readInstances(String path, List<Instance> out, int firstId) {
        int nextId = firstId;
        // try-with-resources closes the reader even when parsing throws mid-file.
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                out.add(new Instance(line, nextId++));
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return nextId;
    }

    /**
     * Returns the fraction of predictions that match the instance labels. A {@code null} prediction
     * counts as wrong.
     *
     * @return accuracy in [0, 1]; 0.0 if the lists differ in size or are empty
     */
    public static double computeAccuracy(List<Integer> list, List<Instance> list2) {
        if (list.size() != list2.size() || list.isEmpty()) {
            return 0.0d;
        }
        int correct = 0;
        for (int k = 0; k < list.size(); k++) {
            Integer predicted = list.get(k);
            if (predicted != null && predicted.intValue() == list2.get(k).label) {
                correct++;
            }
        }
        return (correct * 1.0d) / list.size();
    }

    /**
     * Trains an unpruned entropy tree on {@code training.txt}, writes it to
     * {@code tree_fulltree.xml}, and prints accuracy on the training data and on {@code test.txt}.
     */
    public static void main(String[] strArr) {
        List<Instance> training = new ArrayList<>();
        List<Instance> test = new ArrayList<>();
        load("training.txt", "test.txt", training, test);
        ID3 id3 = new ID3();
        id3.learn(training);
        id3.root.writeTreeXML("tree_fulltree.xml");
        System.out.println("ID3 with full tree on training\t" + computeAccuracy(id3.classify(training), training));
        System.out.println("ID3 with full tree on test\t" + computeAccuracy(id3.classify(test), test));
    }
}
