/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.bayesian.training.search.k2;

import java.util.ArrayList;
import java.util.List;
import org.encog.mathutil.EncogMath;
import org.encog.ml.bayesian.BayesianEvent;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.bayesian.query.enumerate.EnumerationQuery;
import org.encog.ml.bayesian.training.TrainBayesian;
import org.encog.ml.bayesian.training.search.k2.BayesSearch;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;

/**
 * K2 structure search for a {@link BayesianNetwork}.
 * <p>
 * Nodes are visited in a fixed ordering: the classification target first (when one is
 * set), followed by the remaining events in network order. For each node, parents are
 * chosen greedily from the nodes that precede it in that ordering. At every step the
 * candidate that most increases the K2 score is linked. This repeats until no candidate
 * improves the score or the trainer's maximum parent count is reached.
 * <p>
 * The raw K2 score is a product of factorials of data counts, which overflows
 * {@code double} once counts grow large (171! already exceeds {@code Double.MAX_VALUE}).
 * The search therefore compares scores in log space via {@link #calculateLogG}.
 */
public class SearchK2
implements BayesSearch {
    private MLDataSet data;
    private BayesianNetwork network;
    private TrainBayesian train;
    /** Log K2 score of the best candidate found by the most recent {@link #findZ} call. */
    private double lastCalculatedP;
    private final List<BayesianEvent> nodeOrdering = new ArrayList<BayesianEvent>();
    /** Position in {@link #nodeOrdering} of the next node to process; -1 before the first iteration. */
    private int index = -1;

    @Override
    public void init(TrainBayesian theTrainer, BayesianNetwork theNetwork, MLDataSet theData) {
        this.network = theNetwork;
        this.data = theData;
        this.train = theTrainer;
        this.orderNodes();
        this.index = -1;
    }

    /**
     * Rebuilds {@link #nodeOrdering}: the classification target event first (if the
     * network has one), then every other event in the order the network reports them.
     */
    private void orderNodes() {
        this.nodeOrdering.clear();
        if (this.network.getClassificationTarget() != -1) {
            this.nodeOrdering.add(this.network.getClassificationTargetEvent());
        }
        for (BayesianEvent event : this.network.getEvents()) {
            if (!this.nodeOrdering.contains(event)) {
                this.nodeOrdering.add(event);
            }
        }
    }

    /**
     * Finds the predecessor of {@code event} whose addition as a parent yields the highest
     * log K2 score, provided that score is strictly greater than {@code old}.
     *
     * @param event the node whose parent set is being extended
     * @param n     number of leading entries of {@link #nodeOrdering} eligible as parents
     * @param old   log K2 score of {@code event} with its current parents
     * @return the best new parent, or {@code null} if no candidate improves the score.
     *         On return {@link #lastCalculatedP} holds the best score found
     *         ({@code NEGATIVE_INFINITY} if none qualified).
     */
    private BayesianEvent findZ(BayesianEvent event, int n, double old) {
        BayesianEvent best = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        List<BayesianEvent> currentParents = event.getParents();
        for (int i = 0; i < n; i++) {
            BayesianEvent trialParent = this.nodeOrdering.get(i);
            // K2 only considers Pred(x) minus the current parents. Re-scoring an existing
            // parent wastes work, and through rounding it could "win" and stall the search.
            if (currentParents.contains(trialParent)) {
                continue;
            }
            List<BayesianEvent> trialParents = new ArrayList<BayesianEvent>(currentParents);
            trialParents.add(trialParent);
            double score = this.calculateLogG(this.network, event, trialParents);
            if (score > old && score > bestScore) {
                best = trialParent;
                bestScore = score;
            }
        }
        this.lastCalculatedP = bestScore;
        return best;
    }

    /**
     * Returns whether the parents of a data row take exactly the values in
     * {@code parentInstance}.
     *
     * @param network        network used to map each parent event to its index in {@code classes}
     * @param classes        per-event class indices for one data row
     * @param parents        parent events; entry {@code i} pairs with {@code parentInstance[i]}
     * @param parentInstance required class index for each parent
     */
    private static boolean matchesParents(BayesianNetwork network, int[] classes,
            List<BayesianEvent> parents, int[] parentInstance) {
        for (int i = 0; i < parentInstance.length; i++) {
            int parentIndex = network.getEventIndex(parents.get(i));
            if (parentInstance[i] != classes[parentIndex]) {
                return false;
            }
        }
        return true;
    }

    /**
     * Counts the data rows in which {@code event} has class {@code desiredValue} and every
     * parent has the class given in {@code parentInstance} (the K2 count N_ijk).
     *
     * @param network        network used to classify rows and locate events
     * @param event          the child event
     * @param parents        parent events, aligned with {@code parentInstance}
     * @param parentInstance required class index for each parent
     * @param desiredValue   required class index for {@code event}
     * @return number of matching rows in the training data
     */
    public int calculateN(BayesianNetwork network, BayesianEvent event, List<BayesianEvent> parents, int[] parentInstance, int desiredValue) {
        int result = 0;
        int eventIndex = network.getEventIndex(event);
        for (MLDataPair pair : this.data) {
            int[] classes = network.determineClasses(pair.getInput());
            if (classes[eventIndex] == desiredValue
                    && matchesParents(network, classes, parents, parentInstance)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Counts the data rows in which every parent has the class given in
     * {@code parentInstance}, regardless of the child's value (the K2 count N_ij).
     *
     * @param network        network used to classify rows and locate events
     * @param event          the child event (not used in the match itself)
     * @param parents        parent events, aligned with {@code parentInstance}
     * @param parentInstance required class index for each parent
     * @return number of matching rows in the training data
     */
    public int calculateN(BayesianNetwork network, BayesianEvent event, List<BayesianEvent> parents, int[] parentInstance) {
        int result = 0;
        for (MLDataPair pair : this.data) {
            int[] classes = network.determineClasses(pair.getInput());
            if (matchesParents(network, classes, parents, parentInstance)) {
                ++result;
            }
        }
        return result;
    }

    /**
     * Computes the raw K2 score of {@code event} given {@code parents}: the product, over
     * every parent value combination j, of (r-1)!/(N_ij+r-1)! * prod_k N_ijk!.
     * <p>
     * NOTE: the factorials overflow for moderately large counts, which can yield
     * {@code 0}, {@code Infinity} or {@code NaN}. The search itself therefore uses the log-space
     * {@link #calculateLogG}. This method is kept for API compatibility.
     */
    public double calculateG(BayesianNetwork network, BayesianEvent event, List<BayesianEvent> parents) {
        double result = 1.0;
        int r = event.getChoices().size();
        int[] args = new int[parents.size()];
        do {
            double n = EncogMath.factorial(r - 1);
            double d = EncogMath.factorial(this.calculateN(network, event, parents, args) + r - 1);
            double p1 = n / d;
            double p2 = 1.0;
            for (int k = 0; k < r; k++) {
                p2 *= EncogMath.factorial(this.calculateN(network, event, parents, args, k));
            }
            result *= p1 * p2;
        } while (EnumerationQuery.roll(parents, args));
        return result;
    }

    /**
     * Natural logarithm of {@link #calculateG}, computed as a sum of log-factorials so it
     * stays finite for realistic data set sizes. Because the logarithm is monotonic, it
     * ranks candidate parent sets the same way the raw score does when that score is finite.
     */
    private double calculateLogG(BayesianNetwork network, BayesianEvent event, List<BayesianEvent> parents) {
        double result = 0.0;
        int r = event.getChoices().size();
        double logNumerator = logFactorial(r - 1);
        int[] args = new int[parents.size()];
        do {
            result += logNumerator
                    - logFactorial(this.calculateN(network, event, parents, args) + r - 1);
            for (int k = 0; k < r; k++) {
                result += logFactorial(this.calculateN(network, event, parents, args, k));
            }
        } while (EnumerationQuery.roll(parents, args));
        return result;
    }

    /** Returns ln(n!) computed by summation; values of {@code n} below 2 yield 0. */
    private static double logFactorial(int n) {
        double result = 0.0;
        for (int i = 2; i <= n; i++) {
            result += Math.log(i);
        }
        return result;
    }

    /**
     * Performs one K2 step. The first call (re)builds the node ordering. Each later call
     * greedily adds parents to the next node in the ordering, choosing only among nodes
     * that precede it.
     *
     * @return {@code true} while nodes remain to be processed
     */
    @Override
    public boolean iteration() {
        if (this.index == -1) {
            this.orderNodes();
        } else {
            BayesianEvent event = this.nodeOrdering.get(this.index);
            double oldP = this.calculateLogG(this.network, event, event.getParents());
            while (event.getParents().size() < this.train.getMaximumParents()) {
                BayesianEvent z = this.findZ(event, this.index, oldP);
                if (z == null) {
                    break;
                }
                this.network.createDependency(z, event);
                oldP = this.lastCalculatedP;
            }
        }
        ++this.index;
        // Never step past the ordering; the original input-size bound is kept as well.
        // Presumably events map 1:1 to input columns (calculateN indexes
        // determineClasses(pair.getInput()) by event index) -- confirm.
        return this.index < this.nodeOrdering.size() && this.index < this.data.getInputSize();
    }
}

