package bayesnet.jayes.transformation;

import bayesnet.jayes.BayesNet;
import bayesnet.jayes.BayesNode;
import bayesnet.jayes.factor.AbstractFactor;
import bayesnet.jayes.factor.DenseFactor;
import bayesnet.jayes.factor.arraywrapper.DoubleArrayWrapper;
import bayesnet.jayes.transformation.util.ArrayFlatten;
import bayesnet.jayes.transformation.util.CanonicalDoubleArrayManager;
import bayesnet.jayes.transformation.util.DecompositionFailedException;
import bayesnet.jayes.util.MathUtils;
import com.google.common.collect.Lists;
import com.google.common.primitives.Ints;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:bayesnet/jayes/transformation/AbstractDecomposition.class */
/**
 * Base class for decomposition strategies that rewrite the conditional probability table of a node
 * by inserting a new latent node named {@code "latent-<node name>"}.
 *
 * <p>The node's factor is first reordered so that its (last occurring) smallest dimension is the
 * last one. It is then split into rows along that last dimension and handed to the subclass:
 * <ul>
 *   <li>{@link #getBasis} supplies the basis vectors, one per outcome of the latent node.</li>
 *   <li>{@link #toLatentSpace(double[], List)} expresses each row in terms of that basis.</li>
 * </ul>
 *
 * <p>If no reordering was needed, the latent node takes over all parents of the node and becomes
 * its only parent. Otherwise the latent node is placed between the node and the parent whose
 * dimension was moved last.
 */
public abstract class AbstractDecomposition implements IDecompositionStrategy {

    /**
     * Decomposes {@code bayesNode} by creating a latent node in {@code bayesNet} and rewiring the
     * parents and probabilities of {@code bayesNode} accordingly.
     *
     * @param bayesNet the network containing {@code bayesNode}; the latent node is created in it
     * @param bayesNode the node to decompose; its parents and probabilities are replaced
     * @throws IllegalArgumentException if {@code bayesNode} is not part of {@code bayesNet}
     * @throws DecompositionFailedException if the node's factor has a single dimension (no
     *     parents), or if the subclass fails to compute a basis or latent representation
     */
    @Override // bayesnet.jayes.transformation.IDecompositionStrategy
    public final void decompose(BayesNet bayesNet, BayesNode bayesNode) throws DecompositionFailedException {
        if (!bayesNet.getNodes().contains(bayesNode)) {
            throw new IllegalArgumentException("Node " + bayesNode + " is not part of the bayesnet " + bayesNet.getName());
        }
        AbstractFactor factor = bayesNode.getFactor();
        if (factor.getDimensions().length == 1) {
            throw new DecompositionFailedException("Node " + bayesNode + " has no parents, impossible to decompose");
        }
        AbstractFactor reordered = reorderFactor(factor);
        int[] dimensions = reordered.getDimensions();
        // Split the flat table into rows sized by the last dimension
        // (assumes ArrayFlatten.unflatten chunks the array by the given length).
        List<double[]> rows = ArrayFlatten.unflatten(reordered.getValues().toDoubleArray(), dimensions[dimensions.length - 1]);
        List<double[]> basis = getBasis(reordered, rows);
        double[] latentProbabilities = getLatentProbabilities(rows, basis);
        // reorderFactor hands back its argument unchanged when no reordering was necessary.
        if (reordered == factor) {
            createLatentNodeInOriginalOrder(bayesNet, bayesNode, basis, latentProbabilities);
        } else {
            createLatentNodeReordered(bayesNet, bayesNode, reordered, basis, latentProbabilities);
        }
    }

    /**
     * Returns a factor whose last dimension is the last occurrence of the smallest dimension of
     * {@code factor}.
     *
     * <p>If that dimension is already last, {@code factor} itself is returned. Otherwise a new
     * {@link DenseFactor} is built with dimension IDs and sizes rotated right, and filled with the
     * values of {@code factor}.
     */
    private AbstractFactor reorderFactor(AbstractFactor factor) {
        int[] dimensions = factor.getDimensions();
        int smallestIndex = Ints.lastIndexOf(dimensions, Ints.min(dimensions));
        int shift = (dimensions.length - 1) - smallestIndex;
        if (shift == 0) {
            return factor;
        }
        int[] rotatedDimensions = rotateRight(dimensions, shift);
        int[] rotatedIds = rotateRight(factor.getDimensionIDs(), shift);
        DenseFactor rotated = new DenseFactor();
        rotated.setDimensionIDs(rotatedIds);
        rotated.setDimensions(rotatedDimensions);
        // Start from all ones and multiply in the original factor. Assuming multiplyCompatible
        // aligns dimensions by ID, this copies the values into the rotated layout.
        rotated.fill(1.0d);
        rotated.multiplyCompatible(factor);
        return rotated;
    }

    /**
     * Computes the basis vectors for the decomposition. The number of returned vectors becomes the
     * number of outcomes of the latent node.
     *
     * @param abstractFactor the (possibly reordered) factor being decomposed
     * @param list the factor's values split into rows along its last dimension
     * @return the basis vectors
     * @throws DecompositionFailedException if no suitable basis can be found
     */
    protected abstract List<double[]> getBasis(AbstractFactor abstractFactor, List<double[]> list) throws DecompositionFailedException;

    /**
     * Projects every row into the latent space spanned by {@code basis} and flattens the result.
     *
     * <p>Rows and basis vectors are first passed through a fresh {@link CanonicalDoubleArrayManager}
     * (lazily, via {@link Lists#transform}).
     */
    private double[] getLatentProbabilities(List<double[]> rows, List<double[]> basis) throws DecompositionFailedException {
        CanonicalDoubleArrayManager canonicalizer = new CanonicalDoubleArrayManager();
        List<double[]> latentRows = toLatentSpace(Lists.transform(rows, canonicalizer), Lists.transform(basis, canonicalizer));
        return ArrayFlatten.flatten(latentRows.toArray(new double[0][]));
    }

    /** Applies {@link #toLatentSpace(double[], List)} to each row, preserving row order. */
    private List<double[]> toLatentSpace(List<double[]> rows, List<double[]> basis) throws DecompositionFailedException {
        List<double[]> latentRows = new ArrayList<double[]>(rows.size());
        for (double[] row : rows) {
            latentRows.add(toLatentSpace(row, basis));
        }
        return latentRows;
    }

    /**
     * Expresses a single row of the table in terms of the given basis.
     *
     * <p>NOTE(review): the result presumably has one entry per basis vector, since it is used as the
     * latent node's distribution; confirm against the implementations.
     *
     * @param dArr one row of the (possibly reordered) factor
     * @param list the basis vectors
     * @return the row's latent-space representation
     * @throws DecompositionFailedException if the row cannot be represented
     */
    protected abstract double[] toLatentSpace(double[] dArr, List<double[]> list) throws DecompositionFailedException;

    /**
     * Inserts the latent node when no reordering was needed.
     *
     * <p>The latent node inherits all of {@code bayesNode}'s parents and gets
     * {@code latentProbabilities}. {@code bayesNode} then depends only on the latent node, and its
     * table becomes the flattened basis.
     */
    private void createLatentNodeInOriginalOrder(BayesNet bayesNet, BayesNode bayesNode, List<double[]> basis, double[] latentProbabilities) {
        BayesNode latent = bayesNet.createNode("latent-" + bayesNode.getName());
        addOutcomes(latent, basis.size());
        // Defensive copy: bayesNode's parents are replaced right below, and getParents() may
        // expose the live list.
        latent.setParents(new ArrayList<BayesNode>(bayesNode.getParents()));
        latent.setProbabilities(latentProbabilities);
        bayesNode.setParents(Arrays.asList(latent));
        bayesNode.setProbabilities(ArrayFlatten.flatten(basis.toArray(new double[0][])));
    }

    /**
     * Inserts the latent node between {@code bayesNode} and the parent whose dimension
     * {@link #reorderFactor} moved last.
     *
     * <p>That parent becomes the latent node's only parent, and the latent table is the transposed
     * basis. In {@code bayesNode}'s parent list the latent node takes that parent's position.
     * {@code bayesNode}'s probabilities are mapped from the reordered layout back to its own factor
     * layout.
     */
    private void createLatentNodeReordered(BayesNet bayesNet, BayesNode bayesNode, AbstractFactor reordered, List<double[]> basis, double[] latentProbabilities) {
        BayesNode latent = bayesNet.createNode("latent-" + bayesNode.getName());
        addOutcomes(latent, basis.size());
        BayesNode displaced = bayesNet.getNode(reordered.getDimensionIDs()[reordered.getDimensions().length - 1]);
        latent.setParents(Arrays.asList(displaced));
        latent.setProbabilities(ArrayFlatten.flatten(transpose(basis).toArray(new double[0][])));
        List<BayesNode> parents = new ArrayList<BayesNode>(bayesNode.getParents());
        // Replace the displaced parent in place so the order of the other parents is kept.
        parents.set(parents.indexOf(displaced), latent);
        bayesNode.setParents(parents);
        bayesNode.setProbabilities(undoReordering(latentProbabilities, bayesNode.getFactor(), reordered, latent.getId()));
    }

    /** Adds {@code count} outcomes named {@code "outcome0"}, {@code "outcome1"}, ... to {@code node}. */
    private void addOutcomes(BayesNode node, int count) {
        for (int index = 0; index < count; index++) {
            node.addOutcome("outcome" + index);
        }
    }

    /**
     * Maps values laid out like {@code reordered} back onto the layout of {@code target}.
     *
     * <p>The last dimension of the reordered layout is relabelled as the latent node, with its size
     * taken from {@code target}.
     *
     * @param latentProbabilities values in the reordered layout, latent dimension last
     * @param target the factor whose layout the result must match; it is cloned, not modified
     * @param reordered the factor describing the layout of {@code latentProbabilities}; it is cloned
     * @param latentId the dimension ID of the latent node
     * @return the values rearranged into {@code target}'s layout
     */
    private double[] undoReordering(double[] latentProbabilities, AbstractFactor target, AbstractFactor reordered, int latentId) {
        AbstractFactor result = target.mo2clone();
        AbstractFactor source = reordered.mo2clone();
        // NOTE(review): this writes through the arrays returned by the getters. It assumes they are
        // the clone's own backing arrays, and that mo2clone() does not share them with the original.
        int[] sourceIds = source.getDimensionIDs();
        int[] sourceDimensions = source.getDimensions();
        sourceIds[sourceIds.length - 1] = latentId;
        sourceDimensions[sourceDimensions.length - 1] = result.getDimensions()[Ints.indexOf(result.getDimensionIDs(), latentId)];
        source.setValues(new DoubleArrayWrapper(latentProbabilities));
        result.setValues(new DoubleArrayWrapper(new double[MathUtils.product(result.getDimensions())]));
        result.fill(1.0d);
        result.multiplyCompatible(source);
        return result.getValues().toDoubleArray();
    }

    /**
     * Transposes a matrix given as a list of rows.
     *
     * <p>The number of columns is taken from the first row. A later row longer than the first
     * causes an {@link IndexOutOfBoundsException}; a shorter row leaves zeros in the missing cells.
     *
     * @param list the rows of the matrix
     * @return the columns of the matrix, or an empty list when {@code list} is empty
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public final List<double[]> transpose(List<double[]> list) {
        List<double[]> columns = new ArrayList<double[]>();
        if (list.isEmpty()) {
            return columns;
        }
        int rowCount = list.size();
        int columnCount = list.get(0).length;
        for (int column = 0; column < columnCount; column++) {
            columns.add(new double[rowCount]);
        }
        for (int row = 0; row < rowCount; row++) {
            double[] values = list.get(row);
            for (int column = 0; column < values.length; column++) {
                columns.get(column)[row] = values[column];
            }
        }
        return columns;
    }

    /**
     * Returns a copy of {@code values} rotated right by {@code shift} positions; the element at
     * index {@code i} moves to {@code (i + shift) % length}. Requires {@code 0 <= shift <= length}.
     */
    private int[] rotateRight(int[] values, int shift) {
        int[] rotated = new int[values.length];
        System.arraycopy(values, 0, rotated, shift, values.length - shift);
        System.arraycopy(values, values.length - shift, rotated, 0, shift);
        return rotated;
    }
}
