/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.alog;

import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;

/**
 * Computes the forward (alpha) and/or backward (beta) variables of a hidden
 * Markov model for one observation sequence, together with the probability of
 * that sequence under the model.
 *
 * <p>Notation used below: {@code pi(i)} is {@code hmm.getPi(i)},
 * {@code a(i, j)} is {@code hmm.getTransitionProbability(i, j)} and
 * {@code b_j(o)} is {@code hmm.getStateDistribution(j).probability(o)}.
 *
 * <p>The variables are computed without any scaling. Because they are running
 * products of probabilities, they may underflow towards zero for long
 * sequences.
 *
 * <p>The public constructors invoke the protected, overridable
 * {@code compute*} methods.
 */
public class ForwardBackwardCalculator {
    /** Forward variables indexed {@code [t][state]}; {@code null} until computed. */
    protected double[][] alpha;
    /** Backward variables indexed {@code [t][state]}; {@code null} until computed. */
    protected double[][] beta;
    /** Probability of the observation sequence; filled in by the public constructors. */
    protected double probability;

    /**
     * Creates a calculator with nothing computed. This constructor is
     * available to subclasses only.
     */
    protected ForwardBackwardCalculator() {
    }

    /**
     * Computes the forward variables and the sequence probability.
     *
     * @param oseq the observation sequence (must contain at least one element)
     * @param hmm  the model to evaluate the sequence against
     * @throws IllegalArgumentException if {@code oseq} is empty
     */
    public ForwardBackwardCalculator(MLDataSet oseq, HiddenMarkovModel hmm) {
        this(oseq, hmm, EnumSet.of(Computation.ALPHA));
    }

    /**
     * Computes the requested variables and the sequence probability.
     *
     * <p>If {@link Computation#ALPHA} is requested, the probability is derived
     * from the forward variables. Otherwise it is derived from the backward
     * variables.
     *
     * @param oseq  the observation sequence (must contain at least one element)
     * @param hmm   the model to evaluate the sequence against
     * @param flags which variable arrays to compute; must contain
     *              {@link Computation#ALPHA}, {@link Computation#BETA}, or both
     * @throws IllegalArgumentException if {@code oseq} is empty, or if
     *         {@code flags} contains neither {@code ALPHA} nor {@code BETA}
     */
    public ForwardBackwardCalculator(MLDataSet oseq, HiddenMarkovModel hmm, EnumSet<Computation> flags) {
        if (oseq.size() < 1) {
            throw new IllegalArgumentException("Empty sequence");
        }
        final boolean wantAlpha = flags.contains(Computation.ALPHA);
        final boolean wantBeta = flags.contains(Computation.BETA);
        // The probability needs either alpha or beta; with neither requested,
        // computeProbability would dereference a null beta array.
        if (!wantAlpha && !wantBeta) {
            throw new IllegalArgumentException("At least one of ALPHA or BETA must be requested");
        }
        if (wantAlpha) {
            computeAlpha(hmm, oseq);
        }
        if (wantBeta) {
            computeBeta(hmm, oseq);
        }
        computeProbability(oseq, hmm, flags);
    }

    /**
     * Returns one forward variable.
     *
     * @param t time step (index into the observation sequence)
     * @param i state index
     * @return {@code alpha[t][i]}
     * @throws UnsupportedOperationException if alpha was not computed
     */
    public double alphaElement(int t, int i) {
        if (this.alpha == null) {
            throw new UnsupportedOperationException("Alpha array has not been computed");
        }
        return this.alpha[t][i];
    }

    /**
     * Returns one backward variable.
     *
     * @param t time step (index into the observation sequence)
     * @param i state index
     * @return {@code beta[t][i]}
     * @throws UnsupportedOperationException if beta was not computed
     */
    public double betaElement(int t, int i) {
        if (this.beta == null) {
            throw new UnsupportedOperationException("Beta array has not been computed");
        }
        return this.beta[t][i];
    }

    /**
     * Allocates and fills {@link #alpha} for every time step and state.
     *
     * @param hmm  the model
     * @param oseq the observation sequence (assumed non-empty)
     */
    protected void computeAlpha(HiddenMarkovModel hmm, MLDataSet oseq) {
        final int steps = oseq.size();
        final int stateCount = hmm.getStateCount();
        this.alpha = new double[steps][stateCount];

        final MLDataPair first = oseq.get(0);
        for (int i = 0; i < stateCount; i++) {
            computeAlphaInit(hmm, first, i);
        }

        // Indexed access, the same way computeBeta reads the sequence.
        for (int t = 1; t < steps; t++) {
            final MLDataPair observation = oseq.get(t);
            for (int j = 0; j < stateCount; j++) {
                computeAlphaStep(hmm, observation, t, j);
            }
        }
    }

    /**
     * Sets the initial forward variable {@code alpha[0][i] = pi(i) * b_i(o)}.
     *
     * @param hmm the model
     * @param o   the first observation
     * @param i   state index
     */
    protected void computeAlphaInit(HiddenMarkovModel hmm, MLDataPair o, int i) {
        this.alpha[0][i] = hmm.getPi(i) * hmm.getStateDistribution(i).probability(o);
    }

    /**
     * Sets {@code alpha[t][j] = (sum_i alpha[t-1][i] * a(i, j)) * b_j(o)}.
     *
     * @param hmm the model
     * @param o   the observation at time {@code t}
     * @param t   time step (at least 1)
     * @param j   destination state index
     */
    protected void computeAlphaStep(HiddenMarkovModel hmm, MLDataPair o, int t, int j) {
        final double[] previous = this.alpha[t - 1];
        final int stateCount = hmm.getStateCount();
        double sum = 0.0;
        for (int i = 0; i < stateCount; i++) {
            sum += previous[i] * hmm.getTransitionProbability(i, j);
        }
        this.alpha[t][j] = sum * hmm.getStateDistribution(j).probability(o);
    }

    /**
     * Allocates and fills {@link #beta}, working backwards from the last time
     * step, where every entry is 1.
     *
     * @param hmm  the model
     * @param oseq the observation sequence (assumed non-empty)
     */
    protected void computeBeta(HiddenMarkovModel hmm, MLDataSet oseq) {
        final int steps = oseq.size();
        final int stateCount = hmm.getStateCount();
        this.beta = new double[steps][stateCount];

        for (int i = 0; i < stateCount; i++) {
            this.beta[steps - 1][i] = 1.0;
        }

        for (int t = steps - 2; t >= 0; t--) {
            final MLDataPair next = oseq.get(t + 1);
            for (int i = 0; i < stateCount; i++) {
                computeBetaStep(hmm, next, t, i);
            }
        }
    }

    /**
     * Sets {@code beta[t][i] = sum_j beta[t+1][j] * a(i, j) * b_j(o)}.
     *
     * @param hmm the model
     * @param o   the observation at time {@code t + 1}
     * @param t   time step
     * @param i   source state index
     */
    protected void computeBetaStep(HiddenMarkovModel hmm, MLDataPair o, int t, int i) {
        final double[] following = this.beta[t + 1];
        final int stateCount = hmm.getStateCount();
        double sum = 0.0;
        for (int j = 0; j < stateCount; j++) {
            sum += following[j] * hmm.getTransitionProbability(i, j) * hmm.getStateDistribution(j).probability(o);
        }
        this.beta[t][i] = sum;
    }

    /**
     * Computes {@link #probability}. The forward variables are used when they
     * were requested, giving {@code sum_i alpha[T-1][i]}. Otherwise the
     * backward variables are used, giving
     * {@code sum_i pi(i) * b_i(o_0) * beta[0][i]}.
     */
    private void computeProbability(MLDataSet oseq, HiddenMarkovModel hmm, EnumSet<Computation> flags) {
        final int stateCount = hmm.getStateCount();
        this.probability = 0.0;
        if (flags.contains(Computation.ALPHA)) {
            final double[] last = this.alpha[oseq.size() - 1];
            for (int i = 0; i < stateCount; i++) {
                this.probability += last[i];
            }
        } else {
            // The constructor guarantees BETA was requested on this path.
            final MLDataPair first = oseq.get(0);
            for (int i = 0; i < stateCount; i++) {
                this.probability += hmm.getPi(i) * hmm.getStateDistribution(i).probability(first) * this.beta[0][i];
            }
        }
    }

    /**
     * Returns the probability of the observation sequence computed at
     * construction time.
     *
     * @return the sequence probability (unscaled)
     */
    public double probability() {
        return this.probability;
    }

    /** Selects which variable arrays the calculator computes. */
    public static enum Computation {
        /** Compute the forward variables. */
        ALPHA,
        /** Compute the backward variables. */
        BETA;

    }
}

