package org.encog.ml.bayesian.training.search.k2;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.encog.mathutil.EncogMath;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.query.enumerate.EnumerationQuery;
import org.encog.ml.bayesian.training.TrainBayesian;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/* loaded from: input_file:org/encog/ml/bayesian/training/search/k2/SearchK2.class */
/**
 * Searches for the structure of a Bayesian network using the K2 algorithm.
 *
 * <p>Each call to {@link #iteration()} visits one node, in a fixed order. The network's
 * classification target comes first when one is set, followed by the remaining events.
 * The visited node greedily gains parents, chosen only from the nodes that precede it
 * in that order. Parents are added for as long as two conditions hold:
 * <ul>
 *   <li>adding a parent strictly increases the node's score (see {@link #calculateG});</li>
 *   <li>the trainer's maximum-parent limit has not been reached.</li>
 * </ul>
 */
public class SearchK2 implements BayesSearch {
    /** Training data whose rows are counted when scoring candidate parent sets. */
    private MLDataSet data;
    /** Network whose structure is being learned; dependencies are added to it. */
    private BayesianNetwork network;
    /** Trainer that supplies the maximum number of parents per node. */
    private TrainBayesian train;
    /**
     * Score of the best improving candidate from the most recent {@link #findZ} call.
     * It is {@code Double.NEGATIVE_INFINITY} when no candidate improved the score.
     */
    private double lastCalculatedP;
    /** Visiting order of the nodes; only nodes earlier in this list may become parents. */
    private final List<BayesianEvent> nodeOrdering = new ArrayList<BayesianEvent>();
    /**
     * Position in {@link #nodeOrdering} of the node handled by the next iteration.
     * The value -1 means the next iteration only (re)builds the ordering.
     */
    private int index = -1;

    /**
     * Prepares a new search run.
     *
     * @param theTrainer trainer supplying the maximum number of parents per node
     * @param theNetwork network whose structure is searched and extended
     * @param theData    training data used for counting
     */
    @Override
    public void init(TrainBayesian theTrainer, BayesianNetwork theNetwork, MLDataSet theData) {
        this.network = theNetwork;
        this.data = theData;
        this.train = theTrainer;
        orderNodes();
        this.index = -1;
    }

    /**
     * Rebuilds {@link #nodeOrdering}.
     *
     * <p>The classification target event is placed first when the network has one. It is
     * followed by every other event in the order returned by {@code getEvents()}, with no
     * duplicates.
     */
    private void orderNodes() {
        this.nodeOrdering.clear();
        if (this.network.getClassificationTarget() != -1) {
            this.nodeOrdering.add(this.network.getClassificationTargetEvent());
        }
        for (BayesianEvent event : this.network.getEvents()) {
            if (!this.nodeOrdering.contains(event)) {
                this.nodeOrdering.add(event);
            }
        }
    }

    /**
     * Finds the predecessor that most improves the score of {@code event} when added as
     * a parent.
     *
     * <p>Side effect: stores the best score in {@link #lastCalculatedP}. When no candidate
     * beats {@code currentScore}, it stores {@code Double.NEGATIVE_INFINITY}.
     *
     * @param event          node receiving a new parent
     * @param candidateCount number of leading entries of {@link #nodeOrdering} that may be
     *                       considered
     * @param currentScore   score of {@code event} with its current parents; a candidate
     *                       must strictly exceed it
     * @return the best candidate, or {@code null} if none improves the score
     */
    private BayesianEvent findZ(BayesianEvent event, int candidateCount, double currentScore) {
        BayesianEvent best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        List<BayesianEvent> currentParents = event.getParents();
        for (int i = 0; i < candidateCount; i++) {
            BayesianEvent candidate = this.nodeOrdering.get(i);
            if (currentParents.contains(candidate)) {
                // Already a parent: scoring it again scans the whole data set for nothing,
                // and choosing it would only create a duplicate dependency.
                continue;
            }
            List<BayesianEvent> trialParents = new ArrayList<BayesianEvent>(currentParents);
            trialParents.add(candidate);
            double score = calculateG(this.network, event, trialParents);
            if (score > currentScore && score > bestScore) {
                best = candidate;
                bestScore = score;
            }
        }
        this.lastCalculatedP = bestScore;
        return best;
    }

    /**
     * Counts training rows whose parents match {@code parentInstance} and where
     * {@code event} has class {@code desiredValue}.
     *
     * <p>Parents match when the class of {@code parents.get(i)} equals
     * {@code parentInstance[i]} for every {@code i}. Event indices and row classes are
     * both taken from {@code network}.
     *
     * @param network        network used to locate events and classify rows
     * @param event          event whose value is tested
     * @param parents        parent events, aligned with {@code parentInstance}
     * @param parentInstance required class of each parent
     * @param desiredValue   required class of {@code event}
     * @return number of matching rows
     */
    public int calculateN(BayesianNetwork network, BayesianEvent event,
            List<BayesianEvent> parents, int[] parentInstance, int desiredValue) {
        int eventIndex = network.getEventIndex(event);
        int[] indexes = parentIndexes(network, parents, parentInstance.length);
        int count = 0;
        Iterator<MLDataPair> rows = this.data.iterator();
        while (rows.hasNext()) {
            int[] classes = network.determineClasses(rows.next().getInput());
            if (classes[eventIndex] == desiredValue
                    && matchesInstance(classes, indexes, parentInstance)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Counts training rows whose parents match {@code parentInstance}, whatever the value
     * of {@code event}.
     *
     * <p>Row classes and event indices are both taken from {@code network}.
     *
     * @param network        network used to locate events and classify rows
     * @param event          event being scored; its own value is not tested here
     * @param parents        parent events, aligned with {@code parentInstance}
     * @param parentInstance required class of each parent
     * @return number of matching rows
     */
    public int calculateN(BayesianNetwork network, BayesianEvent event,
            List<BayesianEvent> parents, int[] parentInstance) {
        int[] indexes = parentIndexes(network, parents, parentInstance.length);
        int count = 0;
        Iterator<MLDataPair> rows = this.data.iterator();
        while (rows.hasNext()) {
            int[] classes = network.determineClasses(rows.next().getInput());
            if (matchesInstance(classes, indexes, parentInstance)) {
                count++;
            }
        }
        return count;
    }

    /** Looks up, once per count, the network index of each of the first {@code count} parents. */
    private static int[] parentIndexes(BayesianNetwork network, List<BayesianEvent> parents, int count) {
        int[] result = new int[count];
        for (int i = 0; i < count; i++) {
            result[i] = network.getEventIndex(parents.get(i));
        }
        return result;
    }

    /** Returns true when each parent's class in {@code classes} equals its required value. */
    private static boolean matchesInstance(int[] classes, int[] parentIndexes, int[] parentInstance) {
        for (int i = 0; i < parentInstance.length; i++) {
            if (classes[parentIndexes[i]] != parentInstance[i]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Computes the K2 score g of {@code event} given {@code parents}.
     *
     * <p>Let {@code r} be the number of choices of {@code event}. The score is the product,
     * over every parent instantiation j, of:
     * <pre>
     *   (r - 1)! / (N_j + r - 1)!  *  product over k in [0, r) of N_jk!
     * </pre>
     * where N_j and N_jk are the counts returned by the two {@code calculateN} overloads.
     *
     * <p>NOTE(review): {@code EnumerationQuery.roll} is assumed to advance
     * {@code parentInstance} to the next combination, returning false once all combinations
     * have been visited. The score is a raw product of factorials, so it may overflow on
     * larger data sets; a log-space score would be more robust.
     *
     * @param network network used to locate events and classify rows
     * @param event   event being scored
     * @param parents candidate parent set
     * @return the score; larger is better
     */
    public double calculateG(BayesianNetwork network, BayesianEvent event, List<BayesianEvent> parents) {
        double result = 1.0d;
        int r = event.getChoices().size();
        int[] parentInstance = new int[parents.size()];
        do {
            double ratio = EncogMath.factorial(r - 1)
                    / EncogMath.factorial(calculateN(network, event, parents, parentInstance) + r - 1);
            double product = 1.0d;
            for (int k = 0; k < r; k++) {
                product *= EncogMath.factorial(calculateN(network, event, parents, parentInstance, k));
            }
            result *= ratio * product;
        } while (EnumerationQuery.roll(parents, parentInstance));
        return result;
    }

    /**
     * Performs one step of the search.
     *
     * <p>The first call after {@link #init} only rebuilds the node ordering. Each later call
     * takes the next node in the ordering. It keeps adding the best-scoring predecessor as a
     * parent until no predecessor improves the score or the parent limit is reached.
     *
     * <p>NOTE(review): the stopping bound is the data set's input size, which assumes one
     * input column per network event. Confirm this against callers.
     *
     * @return true while there are more nodes to process
     */
    @Override
    public boolean iteration() {
        if (this.index == -1) {
            orderNodes();
        } else {
            BayesianEvent event = this.nodeOrdering.get(this.index);
            double oldP = calculateG(this.network, event, event.getParents());
            while (event.getParents().size() < this.train.getMaximumParents()) {
                BayesianEvent z = findZ(event, this.index, oldP);
                if (z == null) {
                    break;
                }
                this.network.createDependency(z, event);
                oldP = this.lastCalculatedP;
            }
        }
        this.index++;
        return this.index < this.data.getInputSize();
    }
}
