package cc.mallet.topics;

import cc.mallet.classify.MaxEnt;
import cc.mallet.types.Alphabet;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/* loaded from: input_file:cc/mallet/topics/DMRInferencer.class */
/**
 * Topic inferencer whose per-topic {@code alpha} values are recomputed for every
 * instance from a {@link MaxEnt} parameter matrix and the feature vector stored
 * in the instance's target.
 *
 * <p>The sampling itself is delegated to {@link TopicInferencer}; this class only
 * prepares {@code alpha}, {@code cachedCoefficients} and {@code smoothingOnlyMass}
 * before each call, and provides custom Java serialization of its state.
 */
public class DMRInferencer extends TopicInferencer implements Serializable {
    /** Classifier whose flat parameter array holds one row of feature weights per topic. */
    protected MaxEnt dmrParameters;

    /** Size of the alphabet of {@link #dmrParameters}; used as the row stride into its parameter array. */
    protected int numFeatures;

    /** Index of the default feature within each parameter row, as reported by {@link #dmrParameters}. */
    protected int defaultFeatureIndex;

    private static final long serialVersionUID = 1L;

    /** Version tag written at the start of the custom serialized form. */
    private static final int CURRENT_SERIAL_VERSION = 0;

    private static final int NULL_INTEGER = -1;

    /**
     * Builds an inferencer from trained count statistics and DMR parameters.
     *
     * @param iArr     stored as {@code typeTopicCounts}; its length defines {@code numTypes}
     * @param iArr2    stored as {@code tokensPerTopic}; its length defines {@code numTopics}
     * @param maxEnt   the DMR parameters; must not be {@code null}
     * @param alphabet stored as the inferencer's alphabet
     * @param d        stored as {@code beta}
     * @param d2       stored as {@code betaSum}
     */
    public DMRInferencer(int[][] iArr, int[] iArr2, MaxEnt maxEnt, Alphabet alphabet, double d, double d2) {
        this.dmrParameters = maxEnt;
        this.numFeatures = maxEnt.getAlphabet().size();
        this.defaultFeatureIndex = maxEnt.getDefaultFeatureIndex();

        this.tokensPerTopic = iArr2;
        this.typeTopicCounts = iArr;
        this.alphabet = alphabet;
        this.numTopics = iArr2.length;
        this.numTypes = iArr.length;

        // If numTopics is an exact power of two the mask is numTopics - 1;
        // otherwise it is widened to all ones up to the next power of two.
        this.topicMask = (Integer.bitCount(this.numTopics) == 1)
                ? this.numTopics - 1
                : (Integer.highestOneBit(this.numTopics) * 2) - 1;
        this.topicBits = Integer.bitCount(this.topicMask);

        this.beta = d;
        this.betaSum = d2;

        this.cachedCoefficients = new double[this.numTopics];
        this.alpha = new double[this.numTopics];

        // alpha is freshly allocated (all zeros) here, so these terms are zero
        // until getSampledDistribution() fills in instance-specific values.
        for (int topic = 0; topic < this.numTopics; topic++) {
            double denominator = iArr2[topic] + d2;
            this.smoothingOnlyMass += (this.alpha[topic] * d) / denominator;
            this.cachedCoefficients[topic] = this.alpha[topic] / denominator;
        }

        this.random = new Randoms();
    }

    /**
     * Recomputes the per-topic {@code alpha} values for {@code instance} and then
     * delegates to the superclass sampler.
     *
     * <p>For each topic, {@code alpha} is the exponential of the default-feature
     * parameter plus the result of {@link MatrixOps#rowDotProduct} over that topic's
     * parameter row and the instance's target feature vector. {@code cachedCoefficients}
     * and {@code smoothingOnlyMass} are rebuilt from the new {@code alpha} using the
     * same formulas as the constructor, so the three fields stay consistent.
     *
     * @param instance instance whose target must be a {@link FeatureVector}
     * @param i        forwarded unchanged to the superclass
     * @param i2       forwarded unchanged to the superclass
     * @param i3       forwarded unchanged to the superclass
     * @return whatever the superclass implementation returns for the same arguments
     * @throws ClassCastException if the instance target is not a {@link FeatureVector}
     */
    @Override // cc.mallet.topics.TopicInferencer
    public double[] getSampledDistribution(Instance instance, int i, int i2, int i3) {
        // The features are read from the instance target, not from its data.
        FeatureVector features = (FeatureVector) instance.getTarget();
        double[] parameters = this.dmrParameters.getParameters();

        double smoothingMass = 0.0;
        for (int topic = 0; topic < this.numTopics; topic++) {
            double logAlpha = parameters[(topic * this.numFeatures) + this.defaultFeatureIndex]
                    + MatrixOps.rowDotProduct(parameters, this.numFeatures, topic, features,
                            this.defaultFeatureIndex, null);
            double topicAlpha = Math.exp(logAlpha);
            double denominator = this.tokensPerTopic[topic] + this.betaSum;

            this.alpha[topic] = topicAlpha;
            this.cachedCoefficients[topic] = topicAlpha / denominator;
            smoothingMass += (topicAlpha * this.beta) / denominator;
        }
        // Previously this field kept the value computed in the constructor from an
        // all-zero alpha; keep it in step with the instance-specific alpha instead.
        this.smoothingOnlyMass = smoothingMass;

        return super.getSampledDistribution(instance, i, i2, i3);
    }

    /**
     * Writes the inferencer state in a fixed field order, preceded by a version tag.
     * {@link #readObject} must consume the fields in exactly the same order.
     */
    private void writeObject(ObjectOutputStream objectOutputStream) throws IOException {
        objectOutputStream.writeInt(CURRENT_SERIAL_VERSION);
        objectOutputStream.writeObject(this.dmrParameters);
        objectOutputStream.writeObject(this.alphabet);
        objectOutputStream.writeInt(this.numTopics);
        objectOutputStream.writeInt(this.topicMask);
        objectOutputStream.writeInt(this.topicBits);
        objectOutputStream.writeInt(this.numTypes);
        objectOutputStream.writeObject(this.alpha);
        objectOutputStream.writeDouble(this.beta);
        objectOutputStream.writeDouble(this.betaSum);
        objectOutputStream.writeObject(this.typeTopicCounts);
        objectOutputStream.writeObject(this.tokensPerTopic);
        objectOutputStream.writeObject(this.random);
        objectOutputStream.writeDouble(this.smoothingOnlyMass);
        objectOutputStream.writeObject(this.cachedCoefficients);
    }

    /**
     * Restores the state written by {@link #writeObject}. {@code numFeatures} and
     * {@code defaultFeatureIndex} are not stored; they are re-derived from the
     * deserialized {@link #dmrParameters}.
     */
    private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
        // Version tag: consumed to keep the stream aligned, currently not inspected.
        objectInputStream.readInt();

        this.dmrParameters = (MaxEnt) objectInputStream.readObject();
        this.numFeatures = this.dmrParameters.getAlphabet().size();
        this.defaultFeatureIndex = this.dmrParameters.getDefaultFeatureIndex();

        this.alphabet = (Alphabet) objectInputStream.readObject();
        this.numTopics = objectInputStream.readInt();
        this.topicMask = objectInputStream.readInt();
        this.topicBits = objectInputStream.readInt();
        this.numTypes = objectInputStream.readInt();
        this.alpha = (double[]) objectInputStream.readObject();
        this.beta = objectInputStream.readDouble();
        this.betaSum = objectInputStream.readDouble();
        this.typeTopicCounts = (int[][]) objectInputStream.readObject();
        this.tokensPerTopic = (int[]) objectInputStream.readObject();
        this.random = (Randoms) objectInputStream.readObject();
        this.smoothingOnlyMass = objectInputStream.readDouble();
        this.cachedCoefficients = (double[]) objectInputStream.readObject();
    }

    /**
     * Deserializes a {@code DMRInferencer} from {@code file}.
     *
     * <p>NOTE: this uses Java native deserialization; only read files from a trusted source.
     *
     * @param file file containing a serialized {@code DMRInferencer}
     * @return the deserialized inferencer
     * @throws Exception if the file cannot be opened or read, or does not contain a
     *                   {@code DMRInferencer}
     */
    public static DMRInferencer read(File file) throws Exception {
        // try-with-resources closes both streams even when deserialization fails.
        try (FileInputStream fileStream = new FileInputStream(file);
                ObjectInputStream objectStream = new ObjectInputStream(fileStream)) {
            return (DMRInferencer) objectStream.readObject();
        }
    }
}
