/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.topics;

import cc.mallet.types.IDSorter;
import cc.mallet.util.Randoms;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.IntObjectHashMap;
import com.carrotsearch.hppc.cursors.IntCursor;
import com.carrotsearch.hppc.cursors.IntIntCursor;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.zip.GZIPInputStream;

public class MultinomialHMM {
    // Number of topics; set from the constructor's numberOfTopics argument.
    int numTopics;
    // Number of hidden states; set from the constructor's numStates argument.
    int numStates;
    // One more than the largest document ID read by loadTopicsFromFile.
    int numDocs;
    // Count of sequence-ID changes observed by loadSequenceIDsFromFile.
    int numSequences;
    // Per-topic constant added to state-topic counts when the log-gamma caches are rebuilt.
    double[] alpha;
    // Constant added to state-topic totals when the per-document cache is rebuilt;
    // maintained by the constructor and loadAlphaFromFile (presumably the sum of alpha).
    double alphaSum;
    // Not referenced anywhere else in this class.
    double beta;
    // Not referenced anywhere else in this class.
    double betaSum;
    // Constant added to state-to-state transition counts while sampling; set via setGamma.
    double gamma;
    // gamma * numStates, computed in initialize().
    double gammaSum;
    // Constant added to initialStateCounts while sampling; initialize() sets it to 1000.0.
    double pi;
    // numStates * pi, computed in initialize().
    double sumPi;
    // Document ID -> (topic -> number of tokens of that topic in the document).
    IntObjectHashMap<IntIntHashMap> documentTopics;
    // Sequence ID of each document, filled by loadSequenceIDsFromFile.
    int[] documentSequenceIDs;
    // Currently sampled state of each document.
    int[] documentStates;
    // stateTopicCounts[state][topic]: tokens of the topic in documents assigned to the state.
    int[][] stateTopicCounts;
    // stateTopicTotals[state]: total tokens in documents assigned to the state.
    int[] stateTopicTotals;
    // stateStateTransitions[from][to]: transitions between the states of consecutive
    // documents belonging to the same sequence.
    int[][] stateStateTransitions;
    // Per-state transition total used as the denominator of the transition term in sampleState.
    int[] stateTransitionTotals;
    // initialStateCounts[state]: number of sequences whose first document is in the state.
    int[] initialStateCounts;
    // Largest per-document token count seen for each topic; sizes topicLogGammaCache.
    int[] maxTokensPerTopic;
    // Largest document length (in tokens) seen; sizes docLogGammaCache.
    int maxDocLength;
    // topicLogGammaCache[state][topic][n]: running sum of logs rebuilt by recacheStateTopicDistribution.
    double[][][] topicLogGammaCache;
    // docLogGammaCache[state][n]: running sum of logs over the state's token total.
    double[][] docLogGammaCache;
    // Number of Gibbs sweeps performed by sample().
    int numIterations = 1000;
    // Set by setBurninPeriod; not read elsewhere in this class.
    int burninPeriod = 200;
    // Not referenced elsewhere in this class.
    int saveSampleInterval = 10;
    // Set by setOptimizeInterval; not read elsewhere in this class.
    int optimizeInterval = 0;
    // Set by setTopicDisplayInterval; not read elsewhere in this class.
    int showTopicsInterval = 50;
    // Human-readable key words per topic, filled by loadAlphaFromFile.
    String[] topicKeys;
    // Random source used for sampling; created lazily in initialize() if no seed was set.
    Randoms random;
    // Configured in the constructor to show at most 5 fraction digits.
    NumberFormat formatter = NumberFormat.getInstance();

    /**
     * Creates a model over the given number of topics and hidden states, reading the
     * per-document topic counts from a topic-assignment file.
     *
     * <p>After loading, the constructor records the largest per-document count of each
     * topic and the largest document length, and uses them to size the log-gamma caches.
     *
     * @param numberOfTopics number of topics; also used as the initial {@code alphaSum}
     * @param topicsFilename file passed to {@link #loadTopicsFromFile(String)}
     * @param numStates number of hidden states
     * @throws IOException if the topic file cannot be read
     */
    public MultinomialHMM(int numberOfTopics, String topicsFilename, int numStates) throws IOException {
        formatter.setMaximumFractionDigits(5);
        System.out.println("LDA HMM: " + numberOfTopics);

        documentTopics = new IntObjectHashMap<>();
        numTopics = numberOfTopics;

        // Symmetric starting point: every topic gets alphaSum / numTopics.
        alphaSum = numberOfTopics;
        alpha = new double[numberOfTopics];
        Arrays.fill(alpha, alphaSum / (double) numTopics);
        topicKeys = new String[numTopics];

        loadTopicsFromFile(topicsFilename);

        documentStates = new int[numDocs];
        documentSequenceIDs = new int[numDocs];

        // Scan every loaded document for the cache dimensions.
        maxTokensPerTopic = new int[numTopics];
        maxDocLength = 0;
        for (int docIndex = 0; docIndex < numDocs; docIndex++) {
            if (documentTopics.containsKey(docIndex)) {
                IntIntHashMap countsForDoc = documentTopics.get(docIndex);
                int tokensInDoc = 0;
                for (IntIntCursor entry : countsForDoc) {
                    maxTokensPerTopic[entry.key] = Math.max(maxTokensPerTopic[entry.key], entry.value);
                    tokensInDoc += entry.value;
                }
                maxDocLength = Math.max(maxDocLength, tokensInDoc);
            }
        }

        this.numStates = numStates;
        initialStateCounts = new int[numStates];

        // One running-sum array per (state, topic), long enough for the largest count seen.
        topicLogGammaCache = new double[numStates][numTopics][];
        for (double[][] cachesForState : topicLogGammaCache) {
            for (int t = 0; t < numTopics; t++) {
                cachesForState[t] = new double[maxTokensPerTopic[t] + 1];
            }
        }

        System.out.println(maxDocLength);
        docLogGammaCache = new double[numStates][maxDocLength + 1];
    }

    /**
     * Sets gamma, the constant added to state-to-state transition counts during sampling.
     *
     * @param g the new gamma value
     */
    public void setGamma(double g) {
        gamma = g;
    }

    /**
     * Sets how many sweeps over all documents {@link #sample()} performs.
     *
     * @param numIterations the number of sampling iterations
     */
    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    /**
     * Stores the burn-in period. The value is not read elsewhere in this class.
     *
     * @param burninPeriod the burn-in period to store
     */
    public void setBurninPeriod(int burninPeriod) {
        this.burninPeriod = burninPeriod;
    }

    /**
     * Stores the topic display interval. The value is not read elsewhere in this class.
     *
     * @param interval the interval to store
     */
    public void setTopicDisplayInterval(int interval) {
        showTopicsInterval = interval;
    }

    /**
     * Replaces the random source with a new {@code Randoms} built from the given seed.
     *
     * @param seed seed passed to the {@code Randoms} constructor
     */
    public void setRandomSeed(int seed) {
        random = new Randoms(seed);
    }

    /**
     * Stores the optimization interval. The value is not read elsewhere in this class.
     *
     * @param interval the interval to store
     */
    public void setOptimizeInterval(int interval) {
        optimizeInterval = interval;
    }

    /**
     * Allocates the count tables, builds the log-gamma caches for every state, and
     * assigns an initial state to every document by sampling in document order.
     *
     * <p>A default {@code Randoms} is created if no seed was set. {@code pi} is fixed at
     * 1000.0. The sequence count established by {@link #loadSequenceIDsFromFile(String)}
     * is left untouched, so it remains available to the sampler and to
     * {@link #printStateTransitions()}.
     */
    public void initialize() {
        if (this.random == null) {
            this.random = new Randoms();
        }

        this.gammaSum = this.gamma * (double)this.numStates;

        this.stateTopicCounts = new int[this.numStates][this.numTopics];
        this.stateTopicTotals = new int[this.numStates];
        this.stateStateTransitions = new int[this.numStates][this.numStates];
        this.stateTransitionTotals = new int[this.numStates];
        // Allocated once in the constructor; clear it so that calling initialize()
        // again does not accumulate on top of the previous run's counts.
        Arrays.fill(this.initialStateCounts, 0);

        this.pi = 1000.0;
        this.sumPi = (double)this.numStates * this.pi;

        // recacheStateTopicDistribution only rebuilds the topics present as keys of the
        // map it is given, so use a map containing every topic to build the full cache.
        IntIntHashMap allTopicsDummy = new IntIntHashMap();
        for (int topic = 0; topic < this.numTopics; ++topic) {
            allTopicsDummy.put(topic, 1);
        }
        for (int state = 0; state < this.numStates; ++state) {
            this.recacheStateTopicDistribution(state, allTopicsDummy);
        }

        for (int doc = 0; doc < this.numDocs; ++doc) {
            this.sampleState(doc, this.random, true);
        }
    }

    /**
     * Rebuilds the cached running sums of logs for one state.
     *
     * <p>For each topic that is a key of {@code topicCounts}, entry {@code n} of the
     * topic's cache becomes the sum over {@code i = 1..n} of
     * {@code log(alpha[topic] + i - 1 + stateTopicCounts[state][topic])}. The state's
     * document cache is rebuilt the same way from {@code alphaSum} and
     * {@code stateTopicTotals[state]}. Only the keys of {@code topicCounts} are used;
     * its values are ignored.
     *
     * @param state the state whose caches are rebuilt
     * @param topicCounts map whose keys select the topics to rebuild
     */
    private void recacheStateTopicDistribution(int state, IntIntHashMap topicCounts) {
        int[] countsForState = this.stateTopicCounts[state];
        double[][] cachesForState = this.topicLogGammaCache[state];

        for (IntCursor key : topicCounts.keys()) {
            int t = key.value;
            double[] partialSums = cachesForState[t];
            double running = 0.0;
            partialSums[0] = running;
            for (int n = 1; n < partialSums.length; n++) {
                running = running + Math.log(this.alpha[t] + (double)n - 1.0 + (double)countsForState[t]);
                partialSums[n] = running;
            }
        }

        double[] docPartialSums = this.docLogGammaCache[state];
        double runningDoc = 0.0;
        docPartialSums[0] = runningDoc;
        for (int n = 1; n < docPartialSums.length; n++) {
            runningDoc = runningDoc + Math.log(this.alphaSum + (double)n - 1.0 + (double)this.stateTopicTotals[state]);
            docPartialSums[n] = runningDoc;
        }
    }

    /**
     * Runs {@code numIterations} Gibbs sweeps, resampling the state of every document
     * on each sweep.
     *
     * <p>The time of each sweep in milliseconds is printed to standard output. Every
     * 10th iteration three snapshot files are written to the working directory:
     * {@code state_state_matrix.<iteration>}, {@code state_topics.<iteration>} and
     * {@code states.<iteration>} (one document state per line). The total elapsed time
     * is printed at the end.
     *
     * @throws IOException if a snapshot file cannot be opened
     */
    public void sample() throws IOException {
        long startTime = System.currentTimeMillis();
        for (int iterations = 1; iterations <= this.numIterations; ++iterations) {
            long iterationStart = System.currentTimeMillis();
            for (int doc = 0; doc < this.numDocs; ++doc) {
                this.sampleState(doc, this.random, false);
            }
            System.out.print(System.currentTimeMillis() - iterationStart + " ");
            if (iterations % 10 == 0) {
                System.out.println("<" + iterations + "> ");
                writeSnapshot("state_state_matrix." + iterations, this.stateTransitionMatrix());
                writeSnapshot("state_topics." + iterations, this.stateTopics());
                // try-with-resources closes the file even if writing fails part-way.
                try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("states." + iterations)))) {
                    for (int doc = 0; doc < this.documentStates.length; ++doc) {
                        out.println(this.documentStates[doc]);
                    }
                }
            }
            System.out.flush();
        }

        // Break the elapsed time down into days / hours / minutes / seconds.
        long seconds = Math.round((double)(System.currentTimeMillis() - startTime) / 1000.0);
        long minutes = seconds / 60L;
        seconds %= 60L;
        long hours = minutes / 60L;
        minutes %= 60L;
        long days = hours / 24L;
        hours %= 24L;
        System.out.print("\nTotal time: ");
        if (days != 0L) {
            System.out.print(days);
            System.out.print(" days ");
        }
        if (hours != 0L) {
            System.out.print(hours);
            System.out.print(" hours ");
        }
        if (minutes != 0L) {
            System.out.print(minutes);
            System.out.print(" minutes ");
        }
        System.out.print(seconds);
        System.out.println(" seconds");
    }

    /** Writes {@code contents} to {@code filename}, always closing the file. */
    private static void writeSnapshot(String filename, String contents) throws IOException {
        try (PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter(filename)))) {
            out.print(contents);
        }
    }

    /**
     * Reads per-token topic assignments and accumulates, for every document, how many
     * tokens were assigned to each topic.
     *
     * <p>Each data line is split on single spaces. Field 0 is the document ID and field 4
     * is the topic; the other fields are not used. Empty lines and lines starting with
     * {@code #} are skipped. Files whose name ends in {@code .gz} are read through a
     * {@code GZIPInputStream}. {@code numDocs} is set to one more than the largest
     * document ID seen.
     *
     * @param stateFilename path of the topic-assignment file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a document ID or topic is not an integer
     */
    public void loadTopicsFromFile(String stateFilename) throws IOException {
        // try-with-resources: the reader is closed even when a line fails to parse.
        try (BufferedReader in = stateFilename.endsWith(".gz")
                ? new BufferedReader(new InputStreamReader(new GZIPInputStream(new FileInputStream(stateFilename))))
                : new BufferedReader(new FileReader(new File(stateFilename)))) {
            this.numDocs = 0;
            String line;
            while ((line = in.readLine()) != null) {
                if (line.isEmpty() || line.startsWith("#")) {
                    continue;
                }
                String[] fields = line.split(" ");
                int doc = Integer.parseInt(fields[0]);
                int topic = Integer.parseInt(fields[4]);

                IntIntHashMap topicCounts = this.documentTopics.get(doc);
                if (topicCounts == null) {
                    topicCounts = new IntIntHashMap();
                    this.documentTopics.put(doc, topicCounts);
                }
                // addTo inserts the increment when the key is absent, so one call suffices.
                topicCounts.addTo(topic, 1);

                if (doc >= this.numDocs) {
                    this.numDocs = doc + 1;
                }
            }
        }
        System.out.println("loaded topics, " + this.numDocs + " documents");
    }

    /**
     * Reads a topic-keys file, recording the key words of every listed topic and
     * resetting its alpha.
     *
     * <p>Each non-empty line is split on whitespace. Field 0 is the topic index and
     * fields 2 onward are joined (each followed by a space) into {@code topicKeys}.
     * Field 1 is not read: every listed topic gets {@code alpha = 1.0}, and
     * {@code alphaSum} is rebuilt as the sum over the listed topics.
     *
     * @param alphaFilename path of the topic-keys file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a topic index is not an integer
     */
    public void loadAlphaFromFile(String alphaFilename) throws IOException {
        this.alphaSum = 0.0;
        // try-with-resources: the reader is closed even when a line fails to parse.
        try (BufferedReader in = new BufferedReader(new FileReader(new File(alphaFilename)))) {
            String line;
            while ((line = in.readLine()) != null) {
                if (line.equals("")) {
                    continue;
                }
                String[] fields = line.split("\\s+");
                int topic = Integer.parseInt(fields[0]);
                this.alpha[topic] = 1.0;
                this.alphaSum += this.alpha[topic];

                StringBuilder topicKey = new StringBuilder();
                for (int i = 2; i < fields.length; ++i) {
                    topicKey.append(fields[i]).append(' ');
                }
                this.topicKeys[topic] = topicKey.toString();
            }
        }
        System.out.println("loaded alpha");
    }

    /**
     * Reads one sequence ID per line (the first tab-separated field) and assigns it to
     * the document with the same zero-based line index.
     *
     * <p>{@code numSequences} is recomputed from scratch as the number of times the
     * sequence ID changes from one stored document to the next. If the file has more
     * lines than there are documents, the extra lines are counted but not stored, so
     * the mismatch warning is printed instead of an
     * {@code ArrayIndexOutOfBoundsException} being thrown.
     *
     * @param sequenceFilename path of the sequence metadata file
     * @throws IOException if the file cannot be opened or read
     * @throws NumberFormatException if a sequence ID is not an integer
     */
    public void loadSequenceIDsFromFile(String sequenceFilename) throws IOException {
        int doc = 0;
        int currentSequenceID = -1;
        // try-with-resources: the reader is closed even when a line fails to parse.
        try (BufferedReader in = new BufferedReader(new FileReader(new File(sequenceFilename)))) {
            // Start from zero so that loading a second file does not add to the old count.
            this.numSequences = 0;
            String line;
            while ((line = in.readLine()) != null) {
                String[] fields = line.split("\\t");
                int sequenceID = Integer.parseInt(fields[0]);
                if (doc < this.documentSequenceIDs.length) {
                    this.documentSequenceIDs[doc] = sequenceID;
                    if (sequenceID != currentSequenceID) {
                        ++this.numSequences;
                    }
                    currentSequenceID = sequenceID;
                }
                ++doc;
            }
        }
        if (doc != this.numDocs) {
            System.out.println("Warning: number of documents with topics (" + this.numDocs + ") is not equal to number of docs with sequence IDs (" + doc + ")");
        }
        System.out.println("loaded sequence");
    }

    /**
     * Resamples the hidden state of one document, conditioned on the states of its
     * neighbours in the same sequence and on the topic counts of every state.
     *
     * <p>Steps:
     * <ol>
     *   <li>Unless initializing, remove the document's topic counts and its
     *       initial-state / transition counts from its current state.</li>
     *   <li>Compute, for every candidate state, a log score made of a sequence term
     *       (initial-state and/or transition counts, depending on whether the document
     *       starts and/or ends its sequence) plus the cached topic terms.</li>
     *   <li>Draw the new state from the normalised scores and add all counts back.</li>
     * </ol>
     *
     * <p>Documents with no entry in {@code documentTopics} are skipped entirely.
     * When initializing, only the preceding document is taken into account, because
     * the following document has not been assigned a state yet.
     *
     * <p>{@code stateTransitionTotals[s]} is kept equal to the number of counted
     * transitions out of {@code s} that are removed and re-added here, which is what
     * the denominators of the transition terms assume.
     *
     * @param doc index of the document to resample
     * @param r random source used for the draw
     * @param initializing true for the first pass, when the document has no state yet
     */
    private void sampleState(int doc, Randoms r, boolean initializing) {
        if (!this.documentTopics.containsKey(doc)) {
            return;
        }
        IntIntHashMap topicCounts = this.documentTopics.get(doc);
        int oldState = this.documentStates[doc];

        int docLength = 0;
        for (IntIntCursor keyVal : topicCounts) {
            docLength += keyVal.value;
        }

        // Take the document's tokens out of its current state.
        if (!initializing) {
            int[] oldStateTopicCounts = this.stateTopicCounts[oldState];
            for (IntIntCursor keyVal : topicCounts) {
                oldStateTopicCounts[keyVal.key] -= keyVal.value;
            }
            this.stateTopicTotals[oldState] -= docLength;
            this.recacheStateTopicDistribution(oldState, topicCounts);
        }

        // Position within the sequence. Explicit bounds checks (rather than a -1
        // sentinel ID) keep this correct even if a real sequence ID happens to be -1.
        int sequenceID = this.documentSequenceIDs[doc];
        boolean startsSequence = doc == 0 || this.documentSequenceIDs[doc - 1] != sequenceID;
        boolean endsSequence = doc >= this.numDocs - 1 || this.documentSequenceIDs[doc + 1] != sequenceID;

        // Neighbour states; -1 when there is no neighbour to condition on.
        int previousState = startsSequence ? -1 : this.documentStates[doc - 1];
        int nextState = (initializing || endsSequence) ? -1 : this.documentStates[doc + 1];

        double[] stateLogLikelihoods = new double[this.numStates];
        double[] samplingDistribution = new double[this.numStates];
        double initialDenominator = (double)(this.numSequences - 1) + this.sumPi;

        if (initializing) {
            if (startsSequence) {
                for (int state = 0; state < this.numStates; ++state) {
                    stateLogLikelihoods[state] = Math.log(((double)this.initialStateCounts[state] + this.pi) / initialDenominator);
                }
            } else {
                for (int state = 0; state < this.numStates; ++state) {
                    stateLogLikelihoods[state] = Math.log((double)this.stateStateTransitions[previousState][state] + this.gamma);
                    if (Double.isInfinite(stateLogLikelihoods[state])) {
                        System.out.println("infinite end");
                    }
                }
            }
        } else if (startsSequence && endsSequence) {
            // Single-document sequence: only the initial-state count is involved.
            --this.initialStateCounts[oldState];
            for (int state = 0; state < this.numStates; ++state) {
                stateLogLikelihoods[state] = Math.log(((double)this.initialStateCounts[state] + this.pi) / initialDenominator);
            }
        } else if (startsSequence) {
            // First document of a longer sequence.
            --this.initialStateCounts[oldState];
            --this.stateStateTransitions[oldState][nextState];
            assert (this.stateStateTransitions[oldState][nextState] >= 0);
            --this.stateTransitionTotals[oldState];
            for (int state = 0; state < this.numStates; ++state) {
                stateLogLikelihoods[state] = Math.log(((double)this.stateStateTransitions[state][nextState] + this.gamma) * ((double)this.initialStateCounts[state] + this.pi) / initialDenominator);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite beginning");
                }
            }
        } else if (endsSequence) {
            // Last document of a longer sequence.
            --this.stateStateTransitions[previousState][oldState];
            assert (this.stateStateTransitions[previousState][oldState] >= 0);
            for (int state = 0; state < this.numStates; ++state) {
                stateLogLikelihoods[state] = Math.log((double)this.stateStateTransitions[previousState][state] + this.gamma);
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite end");
                }
            }
        } else {
            // Interior document: remove both the incoming and the outgoing transition.
            --this.stateStateTransitions[oldState][nextState];
            if (this.stateStateTransitions[oldState][nextState] < 0) {
                System.out.println(this.printStateTransitions());
                System.out.println(oldState + " -> " + nextState);
                System.out.println(sequenceID);
            }
            assert (this.stateStateTransitions[oldState][nextState] >= 0);
            --this.stateTransitionTotals[oldState];
            --this.stateStateTransitions[previousState][oldState];
            assert (this.stateStateTransitions[previousState][oldState] >= 0);

            for (int state = 0; state < this.numStates; ++state) {
                int outgoingCount = this.stateStateTransitions[state][nextState];
                int outgoingTotal = this.stateTransitionTotals[state];
                if (previousState == state) {
                    // Choosing this state adds previous -> state to its own row first,
                    // so the outgoing total (and, for a self-loop into the next state,
                    // the outgoing count) is one larger.
                    ++outgoingTotal;
                    if (state == nextState) {
                        ++outgoingCount;
                    }
                }
                stateLogLikelihoods[state] = Math.log(((double)this.stateStateTransitions[previousState][state] + this.gamma) * ((double)outgoingCount + this.gamma) / ((double)outgoingTotal + this.gammaSum));
                if (Double.isInfinite(stateLogLikelihoods[state])) {
                    System.out.println("infinite middle: " + doc);
                    System.out.println(previousState + " -> " + state + " -> " + nextState);
                    System.out.println(this.stateStateTransitions[previousState][state] + " -> " + this.stateStateTransitions[state][nextState] + " / " + this.stateTransitionTotals[state]);
                }
            }
        }

        // Add the topic terms from the caches and track the maximum for stable exponentiation.
        double max = Double.NEGATIVE_INFINITY;
        for (int state = 0; state < this.numStates; ++state) {
            // NOTE(review): the division is integer division, as in the original code;
            // the purpose of this per-state penalty is not evident from this file.
            stateLogLikelihoods[state] -= (double)(this.stateTransitionTotals[state] / 10);

            double[][] currentStateLogGammaCache = this.topicLogGammaCache[state];
            for (IntIntCursor keyVal : topicCounts) {
                stateLogLikelihoods[state] += currentStateLogGammaCache[keyVal.key][keyVal.value];
            }
            stateLogLikelihoods[state] -= this.docLogGammaCache[state][docLength];

            if (stateLogLikelihoods[state] > max) {
                max = stateLogLikelihoods[state];
            }
        }

        double sum = 0.0;
        for (int state = 0; state < this.numStates; ++state) {
            samplingDistribution[state] = Math.exp(stateLogLikelihoods[state] - max);
            sum += samplingDistribution[state];
            if (Double.isNaN(samplingDistribution[state])) {
                System.out.println(stateLogLikelihoods[state]);
            }
            assert (!Double.isNaN(samplingDistribution[state]));
        }

        int newState = r.nextDiscrete(samplingDistribution, sum);
        this.documentStates[doc] = newState;

        // Put the document's tokens into the new state.
        int[] newStateTopicCounts = this.stateTopicCounts[newState];
        for (IntIntCursor keyVal : topicCounts) {
            newStateTopicCounts[keyVal.key] += keyVal.value;
        }
        this.stateTopicTotals[newState] += docLength;
        this.recacheStateTopicDistribution(newState, topicCounts);

        // Restore the sequence counts around the new state.
        if (startsSequence) {
            ++this.initialStateCounts[newState];
        } else {
            ++this.stateStateTransitions[previousState][newState];
            if (initializing) {
                // The incoming transition is first counted here, so it belongs to the
                // row of the state it leaves (previousState), matching how totals are
                // decremented when an outgoing transition is removed during sampling.
                ++this.stateTransitionTotals[previousState];
            }
        }
        if (!initializing && !endsSequence) {
            ++this.stateStateTransitions[newState][nextState];
            ++this.stateTransitionTotals[newState];
        }
    }

    /**
     * Builds a human-readable report with one section per state.
     *
     * <p>Each section lists up to four topics (the first entries after sorting the
     * topics by their share of the state's tokens with {@code IDSorter}'s natural
     * ordering) with their token counts and key words, followed by the state's
     * initial-state count, its transition total, and its row of the transition
     * matrix with the self-transition in brackets. Fewer than four topics are listed
     * when the model has fewer than four topics.
     *
     * @return the formatted report
     */
    public String printStateTransitions() {
        StringBuilder out = new StringBuilder();
        IDSorter[] sortedTopics = new IDSorter[this.numTopics];
        // Never index past the end of the sorted array when there are few topics.
        int topicsToShow = Math.min(4, this.numTopics);

        for (int s = 0; s < this.numStates; ++s) {
            for (int topic = 0; topic < this.numTopics; ++topic) {
                sortedTopics[topic] = new IDSorter(topic, (double)this.stateTopicCounts[s][topic] / (double)this.stateTopicTotals[s]);
            }
            Arrays.sort(sortedTopics);

            out.append("\n").append(s).append("\n");
            for (int i = 0; i < topicsToShow; ++i) {
                int topic = sortedTopics[i].getID();
                out.append(this.stateTopicCounts[s][topic]).append("\t").append(this.topicKeys[topic]).append("\n");
            }
            out.append("\n");

            out.append("[").append(this.initialStateCounts[s]).append("/").append(this.numSequences).append("] ");
            out.append("[").append(this.stateTransitionTotals[s]).append("]");
            for (int t = 0; t < this.numStates; ++t) {
                out.append("\t");
                if (s == t) {
                    out.append("[").append(this.stateStateTransitions[s][t]).append("]");
                } else {
                    out.append(this.stateStateTransitions[s][t]);
                }
            }
            out.append("\n");
        }
        return out.toString();
    }

    /**
     * Renders the state-to-state transition counts as text: one line per source state,
     * with every count followed by a tab.
     *
     * @return the tab-separated transition matrix
     */
    public String stateTransitionMatrix() {
        StringBuilder text = new StringBuilder();
        for (int from = 0; from < this.numStates; from++) {
            int[] row = this.stateStateTransitions[from];
            for (int to = 0; to < this.numStates; to++) {
                text.append(row[to]).append("\t");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Renders the state-topic token counts as text: one line per state, with every
     * topic's count followed by a tab.
     *
     * @return the tab-separated state-by-topic count matrix
     */
    public String stateTopics() {
        StringBuilder text = new StringBuilder();
        for (int state = 0; state < this.numStates; state++) {
            int[] countsForState = this.stateTopicCounts[state];
            for (int t = 0; t < this.numTopics; t++) {
                text.append(countsForState[t]).append("\t");
            }
            text.append("\n");
        }
        return text.toString();
    }

    /**
     * Command-line entry point: trains a model with 150 states, gamma 1.0 and random
     * seed 1 on the given files.
     *
     * @param args number of topics, topic-assignment file, topic-keys file, and
     *     sequence metadata file, in that order
     * @throws IOException if any input file cannot be read or a snapshot cannot be written
     */
    public static void main(String[] args) throws IOException {
        if (args.length != 4) {
            System.err.println("Usage: MultinomialHMM [num topics] [lda state file] [lda keys file] [sequence metadata file]");
            // Wrong usage is an error: report it through a non-zero exit status.
            System.exit(1);
        }
        int numTopics = Integer.parseInt(args[0]);
        MultinomialHMM hmm = new MultinomialHMM(numTopics, args[1], 150);
        hmm.setGamma(1.0);
        hmm.setRandomSeed(1);
        hmm.loadAlphaFromFile(args[2]);
        hmm.loadSequenceIDsFromFile(args[3]);
        hmm.initialize();
        hmm.sample();
    }
}

