/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.train.bw;

import java.util.Arrays;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.hmm.distributions.StateDistribution;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

/**
 * Shared skeleton for iterative re-estimation trainers of a
 * {@link HiddenMarkovModel}.
 *
 * <p>Subclasses supply the forward/backward calculator and the {@code xi}
 * estimate; this class turns those into new transition probabilities, new
 * initial-state probabilities and re-fitted state distributions on every call
 * to {@link #iteration()}.
 *
 * <p>Strategies, error reporting and pause/resume are not supported: the
 * corresponding methods are no-ops or return fixed values ({@code null},
 * {@code false} or {@code 0.0}).
 */
public abstract class BaseBaumWelch
implements MLTrain {
    /** Number of completed calls to {@link #iteration()}, unless overridden via {@link #setIteration(int)}. */
    private int iterations;

    /** The current model; replaced by a re-estimated clone after each iteration. */
    private HiddenMarkovModel method;

    /** The observation sequences used for training. */
    private final MLSequenceSet training;

    /**
     * Creates a trainer.
     *
     * @param hmm the model to start from
     * @param training the observation sequences to train on
     */
    public BaseBaumWelch(HiddenMarkovModel hmm, MLSequenceSet training) {
        this.method = hmm;
        this.training = training;
    }

    /**
     * Does nothing; strategies are not supported by this trainer.
     *
     * @param strategy ignored
     */
    @Override
    public void addStrategy(Strategy strategy) {
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Derives per-step state weights from {@code xi}.
     *
     * <p>For every {@code t < xi.length}, {@code gamma[t][i]} is the sum of
     * {@code xi[t][i][j]} over all {@code j}. The extra final row
     * {@code gamma[xi.length][j]} is the sum of {@code xi[xi.length - 1][i][j]}
     * over all {@code i}.
     *
     * @param xi array indexed {@code [t][i][j]}; must contain at least one
     *        {@code t} entry, because the state count is read from
     *        {@code xi[0]}
     * @param fbc not used by this implementation
     * @return an array of {@code xi.length + 1} rows by {@code xi[0].length} columns
     */
    protected double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc) {
        final int steps = xi.length;
        final int states = xi[0].length;
        // New arrays are zero-initialised by the JVM, so no explicit clearing is needed.
        final double[][] gamma = new double[steps + 1][states];

        for (int t = 0; t < steps; t++) {
            for (int i = 0; i < states; i++) {
                for (int j = 0; j < states; j++) {
                    gamma[t][i] += xi[t][i][j];
                }
            }
        }

        // The last row has no outgoing xi entry, so it is built from the
        // incoming side of the final xi step instead.
        for (int j = 0; j < states; j++) {
            for (int i = 0; i < states; i++) {
                gamma[steps][j] += xi[steps - 1][i][j];
            }
        }
        return gamma;
    }

    /**
     * Computes the {@code xi} array for one observation sequence.
     *
     * @param var1 the observation sequence
     * @param var2 the forward/backward calculator produced for that sequence
     * @param var3 the current model
     * @return an array indexed {@code [t][i][j]}
     */
    public abstract double[][][] estimateXi(MLDataSet var1, ForwardBackwardCalculator var2, HiddenMarkovModel var3);

    /**
     * Does nothing.
     */
    @Override
    public void finishTraining() {
    }

    /**
     * Builds the forward/backward calculator for one observation sequence.
     *
     * @param var1 the observation sequence
     * @param var2 the current model
     * @return the calculator to pass to {@link #estimateXi} and {@link #estimateGamma}
     */
    public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet var1, HiddenMarkovModel var2);

    /**
     * @return always {@code 0.0}; this trainer does not track an error value
     */
    @Override
    public double getError() {
        return 0.0;
    }

    /**
     * @return {@link TrainingImplementationType#Iterative}
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * @return the iteration counter: incremented once per completed
     *         {@link #iteration()}, or whatever was last passed to
     *         {@link #setIteration(int)} plus the iterations run since
     */
    @Override
    public int getIteration() {
        return this.iterations;
    }

    /**
     * @return the current model; a different instance after each {@link #iteration()}
     */
    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    /**
     * @return always {@code null}; strategies are not supported
     */
    @Override
    public List<Strategy> getStrategies() {
        return null;
    }

    /**
     * @return the training sequences supplied to the constructor
     */
    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean isTrainingDone() {
        return false;
    }

    /**
     * Runs one re-estimation pass over every training sequence.
     *
     * <p>A clone of the current model receives the new transition
     * probabilities, initial-state probabilities and re-fitted state
     * distributions, and then replaces the current model. The iteration
     * counter is advanced once the pass has completed.
     *
     * @throws InternalError if the current model cannot be cloned; the
     *         {@link CloneNotSupportedException} is attached as the cause
     */
    @Override
    public void iteration() {
        final HiddenMarkovModel nhmm;
        try {
            nhmm = this.method.clone();
        }
        catch (CloneNotSupportedException e) {
            // Keep the original failure attached instead of discarding it.
            InternalError failure = new InternalError(
                    "Unable to clone the hidden Markov model: " + e.getMessage());
            failure.initCause(e);
            throw failure;
        }

        final int stateCount = this.method.getStateCount();
        final int sequenceCount = this.training.getSequenceCount();

        // Freshly allocated arrays are already zero-filled.
        final double[][][] allGamma = new double[sequenceCount][][];
        final double[][] aijNum = new double[stateCount][stateCount];
        final double[] aijDen = new double[stateCount];

        // Accumulate transition numerators/denominators over all sequences.
        int sequenceIndex = 0;
        for (MLDataSet obsSeq : this.training.getSequences()) {
            final ForwardBackwardCalculator fbc = this.generateForwardBackwardCalculator(obsSeq, this.method);
            final double[][][] xi = this.estimateXi(obsSeq, fbc, this.method);
            final double[][] gamma = this.estimateGamma(xi, fbc);
            allGamma[sequenceIndex++] = gamma;

            for (int i = 0; i < stateCount; i++) {
                for (int t = 0; t < obsSeq.size() - 1; t++) {
                    aijDen[i] += gamma[t][i];
                    for (int j = 0; j < stateCount; j++) {
                        aijNum[i][j] += xi[t][i][j];
                    }
                }
            }
        }

        // Transition probabilities: keep the old row when a state collected
        // no weight at all, otherwise use the normalised counts.
        for (int i = 0; i < stateCount; i++) {
            if (aijDen[i] == 0.0) {
                for (int j = 0; j < stateCount; j++) {
                    nhmm.setTransitionProbability(i, j, this.method.getTransitionProbability(i, j));
                }
            } else {
                for (int j = 0; j < stateCount; j++) {
                    nhmm.setTransitionProbability(i, j, aijNum[i][j] / aijDen[i]);
                }
            }
        }

        // Initial-state probabilities: average of the first gamma row of
        // every sequence.
        for (int i = 0; i < stateCount; i++) {
            nhmm.setPi(i, 0.0);
        }
        for (int o = 0; o < sequenceCount; o++) {
            for (int i = 0; i < stateCount; i++) {
                nhmm.setPi(i, nhmm.getPi(i) + allGamma[o][0][i] / (double)sequenceCount);
            }
        }

        // State distributions: fit each one against the whole training set,
        // weighted by that state's normalised gamma values.
        for (int i = 0; i < stateCount; i++) {
            final double[] weights = new double[this.training.size()];
            double sum = 0.0;
            int filled = 0;
            int o = 0;
            for (MLDataSet obsSeq : this.training.getSequences()) {
                for (int t = 0; t < obsSeq.size(); t++) {
                    weights[filled] = allGamma[o][t][i];
                    sum += weights[filled];
                    filled++;
                }
                o++;
            }

            // NOTE(review): a zero sum yields NaN weights here, exactly as
            // before; confirm whether the distributions tolerate that.
            for (int k = filled - 1; k >= 0; k--) {
                weights[k] /= sum;
            }

            final StateDistribution opdf = nhmm.getStateDistribution(i);
            opdf.fit(this.training, weights);
        }

        this.method = nhmm;
        // Record the completed pass so getIteration() reflects training progress.
        this.iterations++;
    }

    /**
     * Runs {@link #iteration()} the given number of times.
     *
     * @param count how many iterations to run; values of zero or less run none
     */
    @Override
    public void iteration(int count) {
        for (int i = 0; i < count; i++) {
            this.iteration();
        }
    }

    /**
     * @return always {@code null}; pausing is not supported
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing; resuming is not supported.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /**
     * Does nothing; this trainer does not track an error value.
     *
     * @param error ignored
     */
    @Override
    public void setError(double error) {
    }

    /**
     * Overrides the iteration counter.
     *
     * @param iteration the new counter value
     */
    @Override
    public void setIteration(int iteration) {
        this.iterations = iteration;
    }
}

