/*
 * Decompiled with CFR 0.152.
 */
package bayesnet.jayes.inference.jtree;

import bayesnet.internal.jayes.util.ArrayUtils;
import bayesnet.jayes.BayesNet;
import bayesnet.jayes.BayesNode;
import bayesnet.jayes.factor.AbstractFactor;
import bayesnet.jayes.factor.arraywrapper.DoubleArrayWrapper;
import bayesnet.jayes.factor.arraywrapper.IArrayWrapper;
import bayesnet.jayes.inference.AbstractInferer;
import bayesnet.jayes.inference.jtree.JunctionTree;
import bayesnet.jayes.inference.jtree.JunctionTreeBuilder;
import bayesnet.jayes.util.Graph;
import bayesnet.jayes.util.MathUtils;
import bayesnet.jayes.util.NumericalInstabilityException;
import bayesnet.jayes.util.OrderIgnoringPair;
import bayesnet.jayes.util.Pair;
import bayesnet.jayes.util.sharing.CanonicalArrayWrapperManager;
import bayesnet.jayes.util.sharing.CanonicalIntArrayManager;
import bayesnet.jayes.util.triangulation.MinFillIn;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;

/**
 * Exact inference on a {@link BayesNet} using the junction tree algorithm.
 * <p>
 * {@link #setNetwork(BayesNet)} compiles the network into a junction tree with the configured
 * {@link JunctionTreeBuilder}. It then multiplies each node's factor into one "home" cluster and runs
 * an initial collect/distribute pass. Finally, it snapshots all cluster and separator potentials.
 * <p>
 * Each later belief update proceeds in three steps:
 * <ol>
 * <li>apply the current evidence as factor selections;</li>
 * <li>restore the snapshot;</li>
 * <li>propagate again, skipping sub-trees whose messages cannot affect unobserved variables.</li>
 * </ol>
 */
public class JunctionTreeAlgorithm
extends AbstractInferer {
    /** Multiplicative identity in linear scale. */
    private static final double ONE = 1.0;
    /** Multiplicative identity in log scale (log 1). */
    private static final double ONE_LOG = 0.0;
    /** Separator potential of every junction tree edge, keyed by the two cluster indices of the edge. */
    protected Map<OrderIgnoringPair<Integer>, AbstractFactor> sepSets;
    /** The junction tree; its vertices are indices into {@link #nodePotentials}. */
    protected Graph junctionTree;
    /** Potential of each cluster, indexed by cluster index. */
    protected AbstractFactor[] nodePotentials;
    /** Key {@code (a, b)}: prepared index map relating the separator between a and b to cluster b's potential. */
    protected Map<Pair<Integer, Integer>, int[]> preparedMultiplications;
    /** For each variable id, the indices of all clusters whose potential contains that variable. */
    protected int[][] concernedClusters;
    /** For each variable id, the smallest (by table length) cluster potential containing it. */
    protected AbstractFactor[] queryFactors;
    /** For each variable id, the prepared index map from its query factor to a single-variable factor. */
    protected int[][] preparedQueries;
    /** Per-variable flag: whether {@code beliefs[id]} reflects the most recent propagation. */
    protected boolean[] isBeliefValid;
    /** Snapshot of every cluster and separator potential taken at the end of {@link #setNetwork(BayesNet)}. */
    protected List<Pair<AbstractFactor, IArrayWrapper>> initializations;
    /** For each cluster index, the variable ids whose query factor is that cluster's potential. */
    protected int[][] queryFactorReverseMapping;
    /** Clusters containing at least one variable observed in the current update. */
    protected Set<Integer> clustersHavingEvidence;
    /** Per-variable flag: whether the variable is observed in the current update. */
    protected boolean[] isObserved;
    /** Reusable buffer for message computation, sized to the largest separator table. */
    protected double[] scratchpad;
    /** Strategy used to build the junction tree; defaults to the min-fill-in triangulation heuristic. */
    protected JunctionTreeBuilder junctionTreeBuilder = JunctionTreeBuilder.forHeuristic(new MinFillIn());

    /**
     * Replaces the strategy used to build the junction tree.
     * It is only consulted by {@link #setNetwork(BayesNet)}, so call this before setting the network.
     *
     * @param bldr the builder to use
     */
    public void setJunctionTreeBuilder(JunctionTreeBuilder bldr) {
        this.junctionTreeBuilder = bldr;
    }

    /**
     * Returns the posterior distribution of {@code node} given the current evidence.
     * Propagation runs lazily on the first query after the evidence changed. After that, each
     * variable's belief is extracted on demand from its query factor.
     *
     * @param node the queried node
     * @return the node's belief vector
     * @throws NumericalInstabilityException if the extracted belief cannot be normalized
     */
    @Override
    public double[] getBeliefs(BayesNode node) {
        if (!this.beliefsValid) {
            this.beliefsValid = true;
            this.updateBeliefs();
        }
        int nodeId = node.getId();
        if (!this.isBeliefValid[nodeId]) {
            if (this.evidence.containsKey(node)) {
                // An observed variable's posterior is a point mass on the observed outcome.
                Arrays.fill(this.beliefs[nodeId], 0.0);
                this.beliefs[nodeId][node.getOutcomeIndex((String) this.evidence.get(node))] = 1.0;
            } else {
                this.validateBelief(nodeId);
            }
            // Mark valid only after success: if validateBelief throws, the next call retries
            // instead of silently returning the unnormalized intermediate values.
            this.isBeliefValid[nodeId] = true;
        }
        return super.getBeliefs(node);
    }

    /**
     * Sums the variable's query factor down to the variable, converts from log scale if needed,
     * and normalizes the result into {@code beliefs[nodeId]}.
     *
     * @throws NumericalInstabilityException if normalization rejects the summed values
     */
    private void validateBelief(int nodeId) {
        AbstractFactor queryFactor = this.queryFactors[nodeId];
        queryFactor.sumPrepared(new DoubleArrayWrapper(this.beliefs[nodeId]), this.preparedQueries[nodeId]);
        if (queryFactor.isLogScale()) {
            MathUtils.exp(this.beliefs[nodeId]);
        }
        try {
            this.beliefs[nodeId] = MathUtils.normalize(this.beliefs[nodeId]);
        }
        catch (IllegalArgumentException exception) {
            throw new NumericalInstabilityException("Numerical instability detected for evidence: " + this.evidence + " and node : " + nodeId + ", consider using logarithmic scale computation (configurable in FactorFactory)", exception);
        }
    }

    /** Invalidates all per-variable beliefs and re-propagates the current evidence. */
    @Override
    protected void updateBeliefs() {
        Arrays.fill(this.isBeliefValid, false);
        this.doUpdateBeliefs();
    }

    /**
     * Applies evidence, restores the initial potentials, then runs collect and distribute from the
     * chosen root while skipping clusters that cannot contribute.
     */
    private void doUpdateBeliefs() {
        this.incorporateAllEvidence();
        int propagationRoot = this.findPropagationRoot();
        this.replayFactorInitializations();
        this.collectEvidence(propagationRoot, this.skipCollection(propagationRoot));
        this.distributeEvidence(propagationRoot, this.skipDistribution(propagationRoot));
    }

    /** Copies the snapshot taken in {@link #storePotentialValues()} back into every potential. */
    private void replayFactorInitializations() {
        for (Pair<AbstractFactor, IArrayWrapper> init : this.initializations) {
            init.getFirst().copyValues(init.getSecond());
        }
    }

    /** Clears previous selections and evidence bookkeeping, then applies every current observation. */
    private void incorporateAllEvidence() {
        for (Pair<AbstractFactor, IArrayWrapper> init : this.initializations) {
            init.getFirst().resetSelections();
        }
        this.clustersHavingEvidence.clear();
        Arrays.fill(this.isObserved, false);
        for (BayesNode observed : this.evidence.keySet()) {
            this.incorporateEvidence(observed);
        }
    }

    /** Selects the observed outcome of {@code node} in every cluster containing it. */
    private void incorporateEvidence(BayesNode node) {
        int nodeId = node.getId();
        this.isObserved[nodeId] = true;
        // The observed outcome is the same for every cluster; look it up once.
        int outcome = node.getOutcomeIndex((String) this.evidence.get(node));
        for (int cluster : this.concernedClusters[nodeId]) {
            this.nodePotentials[cluster].select(nodeId, outcome);
            this.clustersHavingEvidence.add(cluster);
        }
    }

    /**
     * Chooses the cluster at which collect/distribute are rooted.
     *
     * @return the first cluster containing the last evidence variable (in evidence key-set
     *         iteration order), or 0 when there is no evidence
     */
    private int findPropagationRoot() {
        int propagationRoot = 0;
        for (BayesNode observed : this.evidence.keySet()) {
            propagationRoot = this.concernedClusters[observed.getId()][0];
        }
        return propagationRoot;
    }

    /** Returns the clusters whose whole sub-tree (seen from {@code root}) holds no evidence. */
    private Set<Integer> skipCollection(int root) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipCollection(root, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /** Post-order walk marking a cluster skipped when it and all its descendants lack evidence. */
    private void recursiveSkipCollection(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (int neighbor : this.junctionTree.getNeighbors(node)) {
            if (visited.contains(neighbor)) continue;
            this.recursiveSkipCollection(neighbor, visited, skipped);
            if (!skipped.contains(neighbor)) {
                areAllDescendantsSkipped = false;
            }
        }
        if (areAllDescendantsSkipped && !this.clustersHavingEvidence.contains(node)) {
            skipped.add(node);
        }
    }

    /** Returns the clusters whose whole sub-tree (seen from {@code distNode}) serves no unobserved query. */
    private Set<Integer> skipDistribution(int distNode) {
        Set<Integer> skipped = new HashSet<Integer>(this.nodePotentials.length);
        this.recursiveSkipDistribution(distNode, new HashSet<Integer>(this.nodePotentials.length), skipped);
        return skipped;
    }

    /** Post-order walk marking a cluster skipped when neither it nor a descendant is needed for a query. */
    private void recursiveSkipDistribution(int node, Set<Integer> visited, Set<Integer> skipped) {
        visited.add(node);
        boolean areAllDescendantsSkipped = true;
        for (int neighbor : this.junctionTree.getNeighbors(node)) {
            if (visited.contains(neighbor)) continue;
            this.recursiveSkipDistribution(neighbor, visited, skipped);
            if (!skipped.contains(neighbor)) {
                areAllDescendantsSkipped = false;
            }
        }
        if (areAllDescendantsSkipped && !this.isQueryFactorOfUnobservedVariable(node)) {
            skipped.add(node);
        }
    }

    /** Whether cluster {@code node} is the query factor of at least one unobserved variable. */
    private boolean isQueryFactorOfUnobservedVariable(int node) {
        for (int var : this.queryFactorReverseMapping[node]) {
            if (!this.isObserved[var]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Collect phase: recursively pulls messages from every unmarked neighbor into {@code cluster}.
     * Clusters already in {@code marked} (including pre-skipped ones) are not visited.
     */
    private void collectEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (int neighbor : this.junctionTree.getNeighbors(cluster)) {
            if (marked.contains(neighbor)) continue;
            this.collectEvidence(neighbor, marked);
            this.messagePass(neighbor, cluster);
        }
    }

    /**
     * Distribute phase: recursively pushes messages from {@code cluster} to every unmarked neighbor.
     * Clusters already in {@code marked} (including pre-skipped ones) are not visited.
     */
    private void distributeEvidence(int cluster, Set<Integer> marked) {
        marked.add(cluster);
        for (int neighbor : this.junctionTree.getNeighbors(cluster)) {
            if (marked.contains(neighbor)) continue;
            this.messagePass(cluster, neighbor);
            this.distributeEvidence(neighbor, marked);
        }
    }

    /**
     * Sends a message from cluster {@code v1} to cluster {@code v2} over their separator (Hugin-style):
     * <ol>
     * <li>the new separator values are obtained by summing {@code v1}'s potential onto the separator;</li>
     * <li>the ratio of new to previous separator values is multiplied into {@code v2}'s potential.</li>
     * </ol>
     * The step is skipped when every separator variable is observed.
     *
     * @param v1 source cluster index
     * @param v2 target cluster index
     */
    private void messagePass(int v1, int v2) {
        AbstractFactor sepSet = this.sepSets.get(new OrderIgnoringPair<Integer>(v1, v2));
        if (!this.needMessagePass(sepSet)) {
            return;
        }
        // Direction is taken from (v1, v2) explicitly, matching the directed keys of
        // preparedMultiplications, rather than from the element order of the undirected edge key.
        AbstractFactor source = this.nodePotentials[v1];
        AbstractFactor target = this.nodePotentials[v2];
        boolean sourceIsLog = source.isLogScale();
        boolean targetIsLog = target.isLogScale();

        IArrayWrapper newSepValues = sepSet.getValues();
        // Keep the previous separator values; sumPrepared overwrites the separator in place.
        System.arraycopy(newSepValues.toDoubleArray(), 0, this.scratchpad, 0, newSepValues.length());
        source.sumPrepared(newSepValues, this.preparedMultiplications.get(Pair.newPair(v2, v1)));
        // Separators between mixed-scale clusters are stored linearly (see initializePotentialValues).
        if (sourceIsLog && !targetIsLog) {
            MathUtils.exp(newSepValues);
        }
        if (sourceIsLog && targetIsLog) {
            MathUtils.secureSubtract(newSepValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        } else {
            MathUtils.secureDivide(newSepValues.toDoubleArray(), this.scratchpad, this.scratchpad);
        }
        if (!sourceIsLog && targetIsLog) {
            MathUtils.log(this.scratchpad);
        }
        target.multiplyPrepared(new DoubleArrayWrapper(this.scratchpad), this.preparedMultiplications.get(Pair.newPair(v1, v2)));
    }

    /** A message is only needed if the separator contains at least one unobserved variable. */
    private boolean needMessagePass(AbstractFactor sepSet) {
        for (int var : sepSet.getDimensionIDs()) {
            if (!this.isObserved[var]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Compiles {@code net} into a junction tree, prepares all potentials and index maps, runs the
     * initial propagation, and snapshots the resulting potentials for later updates.
     */
    @Override
    public void setNetwork(BayesNet net) {
        super.setNetwork(net);
        this.initializeFields(net.getNodes().size());
        JunctionTree jtree = this.buildJunctionTree(net);
        Map<AbstractFactor, Integer> homeClusters = this.computeHomeClusters(net, jtree.getClusters());
        this.initializeClusterFactors(net, jtree.getClusters(), homeClusters);
        this.initializeSepsetFactors(jtree.getSepSets());
        this.determineConcernedClusters();
        this.setQueryFactors();
        this.initializePotentialValues();
        this.multiplyCPTsIntoPotentials(net, homeClusters);
        this.prepareMultiplications();
        this.prepareScratch();
        this.invokeInitialBeliefUpdate();
        this.storePotentialValues();
    }

    /** Builds {@link #concernedClusters}: for each variable, the clusters containing it. */
    private void determineConcernedClusters() {
        int numVariables = this.queryFactors.length;
        List<List<Integer>> clustersPerVariable = new ArrayList<List<Integer>>(numVariables);
        for (int var = 0; var < numVariables; ++var) {
            clustersPerVariable.add(new ArrayList<Integer>());
        }
        for (int cluster = 0; cluster < this.nodePotentials.length; ++cluster) {
            for (int var : this.nodePotentials[cluster].getDimensionIDs()) {
                clustersPerVariable.get(var).add(cluster);
            }
        }
        this.concernedClusters = new int[numVariables][];
        for (int var = 0; var < numVariables; ++var) {
            this.concernedClusters[var] = ArrayUtils.toIntArray(clustersPerVariable.get(var));
        }
    }

    /** Allocates per-network bookkeeping; relies on {@code beliefs} set up by the superclass. */
    private void initializeFields(int numNodes) {
        // A fresh boolean[] is already all-false.
        this.isBeliefValid = new boolean[this.beliefs.length];
        this.queryFactors = new AbstractFactor[numNodes];
        this.preparedQueries = new int[numNodes][];
        this.sepSets = new HashMap<OrderIgnoringPair<Integer>, AbstractFactor>(numNodes);
        this.preparedMultiplications = new HashMap<Pair<Integer, Integer>, int[]>(numNodes);
        this.initializations = new ArrayList<Pair<AbstractFactor, IArrayWrapper>>();
        this.clustersHavingEvidence = new HashSet<Integer>(numNodes);
        this.isObserved = new boolean[numNodes];
    }

    /** Builds the junction tree with the configured builder and stores its graph. */
    private JunctionTree buildJunctionTree(BayesNet net) {
        JunctionTree jtree = this.junctionTreeBuilder.buildJunctionTree(net);
        this.junctionTree = jtree.getGraph();
        return jtree;
    }

    /**
     * Assigns each node's factor to the first cluster containing the node and all its factor's
     * dimensions.
     *
     * @throws IllegalStateException if no cluster contains some node's family, i.e. the junction
     *         tree builder produced an invalid tree
     */
    private Map<AbstractFactor, Integer> computeHomeClusters(BayesNet net, List<List<Integer>> clusters) {
        Map<AbstractFactor, Integer> homeClusters = new HashMap<AbstractFactor, Integer>();
        for (BayesNode node : net.getNodes()) {
            AbstractFactor family = node.getFactor();
            int[] nodeAndParents = family.getDimensionIDs();
            int home = this.indexOfFirstClusterContaining(clusters, nodeAndParents);
            if (home < 0) {
                // Fail here with context instead of an unboxing NPE in multiplyCPTsIntoPotentials.
                throw new IllegalStateException("Junction tree has no cluster containing the family " + Arrays.toString(nodeAndParents) + " of node " + node.getId());
            }
            homeClusters.put(family, home);
        }
        return homeClusters;
    }

    /** Returns the index of the first cluster containing all of {@code vars}, or -1 if none does. */
    private int indexOfFirstClusterContaining(List<List<Integer>> clusters, int[] vars) {
        int index = 0;
        for (List<Integer> cluster : clusters) {
            if (this.containsAll(cluster, vars)) {
                return index;
            }
            ++index;
        }
        return -1;
    }

    /** Whether {@code list} contains every value in {@code ints}. */
    private boolean containsAll(List<Integer> list, int[] ints) {
        for (int n : ints) {
            if (!list.contains(n)) {
                return false;
            }
        }
        return true;
    }

    /** Creates one potential per cluster, passing the factors homed in it as multiplication partners. */
    private void initializeClusterFactors(BayesNet net, List<List<Integer>> clusters, Map<AbstractFactor, Integer> homeClusters) {
        this.nodePotentials = new AbstractFactor[clusters.size()];
        Map<Integer, List<AbstractFactor>> multiplicationPartners = this.findMultiplicationPartners(net, homeClusters);
        int current = 0;
        for (List<Integer> cluster : clusters) {
            List<AbstractFactor> partners = multiplicationPartners.get(current);
            if (partners == null) {
                partners = Collections.<AbstractFactor>emptyList();
            }
            this.nodePotentials[current] = this.factory.create(cluster, partners);
            ++current;
        }
    }

    /** Groups node factors by their home cluster index. */
    private Map<Integer, List<AbstractFactor>> findMultiplicationPartners(BayesNet net, Map<AbstractFactor, Integer> homeClusters) {
        Map<Integer, List<AbstractFactor>> potentialMap = new HashMap<Integer, List<AbstractFactor>>();
        for (BayesNode node : net.getNodes()) {
            Integer nodeHome = homeClusters.get(node.getFactor());
            List<AbstractFactor> partners = potentialMap.get(nodeHome);
            if (partners == null) {
                partners = new ArrayList<AbstractFactor>();
                potentialMap.put(nodeHome, partners);
            }
            partners.add(node.getFactor());
        }
        return potentialMap;
    }

    /** Creates a separator potential over each separator's variables. */
    private void initializeSepsetFactors(List<Pair<OrderIgnoringPair<Integer>, List<Integer>>> sepSets) {
        for (Pair<OrderIgnoringPair<Integer>, List<Integer>> sep : sepSets) {
            this.sepSets.put(sep.getFirst(), this.factory.create(sep.getSecond(), Collections.<AbstractFactor>emptyList()));
        }
    }

    /**
     * Picks, for each variable, the smallest containing cluster as its query factor (the first one
     * wins on ties), and builds the reverse mapping from clusters to the variables they serve.
     */
    private void setQueryFactors() {
        for (int var = 0; var < this.queryFactors.length; ++var) {
            for (int cluster : this.concernedClusters[var]) {
                AbstractFactor candidate = this.nodePotentials[cluster];
                AbstractFactor best = this.queryFactors[var];
                if (best == null || best.getValues().length() > candidate.getValues().length()) {
                    this.queryFactors[var] = candidate;
                }
            }
        }
        this.queryFactorReverseMapping = new int[this.nodePotentials.length][];
        for (int cluster = 0; cluster < this.nodePotentials.length; ++cluster) {
            List<Integer> queryVars = new ArrayList<Integer>();
            for (int var : this.nodePotentials[cluster].getDimensionIDs()) {
                if (this.queryFactors[var] == this.nodePotentials[cluster]) {
                    queryVars.add(var);
                }
            }
            this.queryFactorReverseMapping[cluster] = ArrayUtils.toIntArray(queryVars);
        }
    }

    /** Precomputes all index maps, sharing identical arrays via a canonicalizing manager. */
    private void prepareMultiplications() {
        CanonicalIntArrayManager flyWeight = new CanonicalIntArrayManager();
        this.prepareSepsetMultiplications(flyWeight);
        this.prepareQueries(flyWeight);
    }

    /** For every directed edge (node, n), prepares the map from their separator onto cluster n. */
    private void prepareSepsetMultiplications(CanonicalIntArrayManager flyWeight) {
        for (int node = 0; node < this.nodePotentials.length; ++node) {
            for (int n : this.junctionTree.getNeighbors(node)) {
                int[] preparedMultiplication = this.nodePotentials[n].prepareMultiplication(this.sepSets.get(new OrderIgnoringPair<Integer>(node, n)));
                this.preparedMultiplications.put(Pair.newPair(node, n), flyWeight.getInstance(preparedMultiplication));
            }
        }
    }

    /** For every variable, prepares the map from its query factor onto a single-variable factor. */
    private void prepareQueries(CanonicalIntArrayManager flyWeight) {
        for (int i = 0; i < this.queryFactors.length; ++i) {
            AbstractFactor beliefFactor = this.factory.create(Arrays.asList(i), Collections.<AbstractFactor>emptyList());
            int[] preparedQuery = this.queryFactors[i].prepareMultiplication(beliefFactor);
            this.preparedQueries[i] = flyWeight.getInstance(preparedQuery);
        }
    }

    /** Sizes {@link #scratchpad} to the largest separator table (0 when there are no separators). */
    private void prepareScratch() {
        int maxSize = 0;
        for (AbstractFactor sepSet : this.sepSets.values()) {
            maxSize = Math.max(maxSize, sepSet.getValues().length());
        }
        this.scratchpad = new double[maxSize];
    }

    /** Runs a full, unskipped collect/distribute from cluster 0 with no evidence. */
    private void invokeInitialBeliefUpdate() {
        if (this.nodePotentials.length == 0) {
            // Empty tree (no clusters): there is no cluster 0 to root propagation at.
            return;
        }
        this.collectEvidence(0, new HashSet<Integer>());
        this.distributeEvidence(0, new HashSet<Integer>());
    }

    /**
     * Fills every potential with the multiplicative identity of its scale. A separator uses log
     * scale only when both adjacent clusters do; otherwise it is kept linear.
     */
    private void initializePotentialValues() {
        for (AbstractFactor potential : this.nodePotentials) {
            potential.fill(potential.isLogScale() ? ONE_LOG : ONE);
        }
        for (Map.Entry<OrderIgnoringPair<Integer>, AbstractFactor> entry : this.sepSets.entrySet()) {
            entry.getValue().fill(this.areBothEndsLogScale(entry.getKey()) ? ONE_LOG : ONE);
        }
    }

    /** Multiplies each node's factor into its home cluster, in the home cluster's scale. */
    private void multiplyCPTsIntoPotentials(BayesNet net, Map<AbstractFactor, Integer> homeClusters) {
        for (BayesNode node : net.getNodes()) {
            AbstractFactor nodeHome = this.nodePotentials[homeClusters.get(node.getFactor())];
            if (nodeHome.isLogScale()) {
                nodeHome.multiplyCompatibleToLog(node.getFactor());
            } else {
                nodeHome.multiplyCompatible(node.getFactor());
            }
        }
    }

    /** Whether both clusters of {@code edge} use log scale (symmetric in the two ends). */
    private boolean areBothEndsLogScale(OrderIgnoringPair<Integer> edge) {
        return this.nodePotentials[edge.getFirst()].isLogScale() && this.nodePotentials[edge.getSecond()].isLogScale();
    }

    /** Snapshots cluster and separator values (sharing identical copies) for replay before each update. */
    private void storePotentialValues() {
        CanonicalArrayWrapperManager flyweight = new CanonicalArrayWrapperManager();
        for (AbstractFactor pot : this.nodePotentials) {
            this.initializations.add(Pair.newPair(pot, flyweight.getInstance(pot.getValues().clone())));
        }
        for (AbstractFactor sep : this.sepSets.values()) {
            this.initializations.add(Pair.newPair(sep, flyweight.getInstance(sep.getValues().clone())));
        }
    }
}

