/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.train.bw;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.hmm.train.bw.BaseBaumWelch;

/**
 * {@link BaseBaumWelch} implementation that supplies the xi and gamma
 * estimation steps, normalising xi by the sequence probability reported by
 * the {@link ForwardBackwardCalculator}.
 */
public class TrainBaumWelch extends BaseBaumWelch {

    /**
     * Creates a trainer; both arguments are passed straight to the superclass.
     *
     * @param hmm      the model handed to {@link BaseBaumWelch}
     * @param training the training sequences handed to {@link BaseBaumWelch}
     */
    public TrainBaumWelch(HiddenMarkovModel hmm, MLSequenceSet training) {
        super(hmm, training);
    }

    /**
     * Derives gamma from xi by marginalising over one of the two state indices.
     *
     * <p>For every {@code t < xi.length}, {@code gamma[t][i]} is the sum of
     * {@code xi[t][i][j]} over all {@code j}. The extra final row
     * {@code gamma[xi.length][j]} is the sum of {@code xi[xi.length - 1][i][j]}
     * over all {@code i}.
     *
     * @param xi  xi values indexed {@code [t][i][j]}; must contain at least one
     *            time step, and {@code xi[0].length} is used as the bound for
     *            both state indices
     * @param fbc not used by this implementation
     * @return a new array with {@code xi.length + 1} rows and
     *         {@code xi[0].length} columns
     */
    @Override
    protected double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc) {
        final int steps = xi.length;
        final int states = xi[0].length;
        // A freshly allocated double array is already zero-filled, so no explicit
        // clearing is needed before accumulating.
        final double[][] gamma = new double[steps + 1][states];

        for (int t = 0; t < steps; t++) {
            for (int i = 0; i < states; i++) {
                for (int j = 0; j < states; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        // xi has one entry fewer than gamma, so the last gamma row is built from
        // the final xi slice by summing over the first state index instead.
        for (int j = 0; j < states; j++) {
            for (int i = 0; i < states; i++) {
                gamma[steps][j] += xi[steps - 1][i][j];
            }
        }
        return gamma;
    }

    /**
     * Computes xi for every consecutive pair of observations in a sequence.
     *
     * <p>Each entry is
     * {@code alpha(t, i) * a(i, j) * b_j(o) * beta(t + 1, j) / probability},
     * where {@code o} is the element at position {@code t + 1} of the sequence
     * and {@code probability} is {@code fbc.probability()}. No guard is applied
     * against a zero {@code probability}.
     *
     * @param sequence the observation sequence; must hold more than one element
     * @param fbc      calculator providing alpha, beta and the sequence probability
     * @param hmm      model providing the state count, transition probabilities
     *                 and per-state distributions
     * @return a new array of size
     *         {@code [sequence.size() - 1][stateCount][stateCount]}
     * @throws IllegalArgumentException if {@code sequence.size() <= 1}
     */
    @Override
    public double[][][] estimateXi(MLDataSet sequence, ForwardBackwardCalculator fbc, HiddenMarkovModel hmm) {
        final int length = sequence.size();
        if (length <= 1) {
            throw new IllegalArgumentException("Must have more than one observation");
        }

        final int states = hmm.getStateCount();
        final double[][][] xi = new double[length - 1][states][states];
        final double probability = fbc.probability();

        final Iterator<MLDataPair> seqIterator = sequence.iterator();
        // Skip the first observation: xi[t] uses the observation at position t + 1.
        seqIterator.next();

        for (int t = 0; t < length - 1; t++) {
            final MLDataPair o = seqIterator.next();
            for (int i = 0; i < states; i++) {
                for (int j = 0; j < states; j++) {
                    xi[t][i][j] = fbc.alphaElement(t, i) * hmm.getTransitionProbability(i, j) * hmm.getStateDistribution(j).probability(o) * fbc.betaElement(t + 1, j) / probability;
                }
            }
        }
        return xi;
    }

    /**
     * Builds a calculator with every {@link ForwardBackwardCalculator.Computation}
     * flag enabled.
     *
     * @param sequence the observation sequence passed to the calculator
     * @param hmm      the model passed to the calculator
     * @return a new {@link ForwardBackwardCalculator}
     */
    @Override
    public ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet sequence, HiddenMarkovModel hmm) {
        return new ForwardBackwardCalculator(sequence, hmm, EnumSet.allOf(ForwardBackwardCalculator.Computation.class));
    }
}

