package cc.mallet.topics;

import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import java.util.Random;

/* loaded from: input_file:cc/mallet/topics/WordEmbeddingRunnable.class */
/**
 * Worker that trains word embeddings on one slice of an {@link InstanceList}, in the style of
 * word2vec's skip-gram with negative sampling.
 *
 * <p>With {@code n} documents, worker {@code threadID} of {@code numThreads} trains on documents
 * {@code [threadID * (n / numThreads), (threadID + 1) * (n / numThreads))}; the last worker also
 * takes the {@code n % numThreads} leftover documents. Updates are written directly into the
 * shared {@code model.weights} and {@code model.negativeWeights} arrays without any locking, so
 * updates from concurrently running workers can interleave.
 */
public class WordEmbeddingRunnable implements Runnable {

    /** {@link #setOrdering} value: visit this worker's documents in a random permutation. */
    private static final int SHUFFLED_ORDERING = 1;
    /** {@link #setOrdering} value: sample this worker's documents uniformly with replacement. */
    private static final int RANDOM_ORDERING = 2;

    /** Learning rate at the start of training. */
    private static final double INITIAL_LEARNING_RATE = 0.025d;
    /** Floor for the linearly decayed learning rate (1e-4 of the initial rate). */
    private static final double MIN_LEARNING_RATE = 2.5E-6d;
    /** Number of processed tokens between learning-rate recomputations. */
    private static final long LEARNING_RATE_UPDATE_INTERVAL = 10000;
    /** Initial capacity of the per-document buffer of retained word indices; grown on demand. */
    private static final int INITIAL_BUFFER_SIZE = 100000;

    public WordEmbeddings model;
    public InstanceList instances;
    /** Negative samples drawn per (center, context) pair. */
    public int numSamples;
    int numThreads;
    int threadID;
    /** Distance between consecutive word vectors in the flat weight arrays. */
    int stride;
    /** Position of the next document within this worker's document ordering. */
    int docID;
    /** Number of vector components that are read and updated per word. */
    int numColumns;
    private int minDocumentLength;
    /**
     * Cleared by the worker when it finishes; clearing it externally stops the worker before its
     * next document. Volatile because it is read and written from different threads.
     */
    public volatile boolean shouldRun = true;
    double residual = 0.0d;
    int numUpdates = 0;
    /** Number of completed passes over this worker's document ordering. */
    int iteration = 0;
    int orderingStrategy = 0;
    /** Tokens (before subsampling) this worker has processed so far. */
    public long wordsSoFar = 0;
    public Random random = new Random();

    /**
     * Creates a worker bound to {@code model}.
     *
     * @param model the embedding model whose arrays and hyperparameters are read and updated
     * @param instances documents whose data are {@link FeatureSequence}s of word indices
     * @param numSamples number of negative samples per (center, context) pair
     * @param numThreads total number of workers sharing {@code instances}
     * @param threadID this worker's index, expected in {@code [0, numThreads)}
     */
    public WordEmbeddingRunnable(WordEmbeddings model, InstanceList instances, int numSamples, int numThreads, int threadID) {
        this.model = model;
        this.stride = model.stride;
        this.instances = instances;
        this.numSamples = numSamples;
        this.numThreads = numThreads;
        this.threadID = threadID;
        this.numColumns = model.numColumns;
        this.minDocumentLength = model.getMinDocumentLength();
    }

    /** Replaces this worker's random source with one seeded by {@code seed}. */
    public void setRandomSeed(int seed) {
        this.random = new Random(seed);
    }

    /**
     * Selects how {@link #run()} orders this worker's documents: {@code 1} = random permutation,
     * {@code 2} = uniform sampling with replacement, anything else = sequential order.
     */
    public void setOrdering(int strategy) {
        this.orderingStrategy = strategy;
    }

    /**
     * Returns {@code residual / numUpdates} and resets both counters.
     *
     * <p>NOTE(review): {@link #run()} increments {@code numUpdates} but never adds to
     * {@code residual}, so this returns 0.0 whenever any update happened. When no update happened
     * it returns the current document position {@code docID} rather than an error value. This is
     * kept for compatibility; confirm what callers expect.
     */
    public double getMeanError() {
        if (this.numUpdates == 0) {
            return this.docID;
        }
        double mean = this.residual / this.numUpdates;
        this.residual = 0.0d;
        this.numUpdates = 0;
        return mean;
    }

    /**
     * Trains on this worker's documents for {@code model.numIterations} passes, then clears
     * {@link #shouldRun}. Returns early if {@link #shouldRun} is cleared externally, or if this
     * worker has no documents.
     */
    @Override
    public void run() {
        assert minDocumentLength > 0;

        int numDocuments = instances.size();
        int docsPerThread = numDocuments / numThreads;
        int firstDoc = threadID * docsPerThread;
        // The last worker also takes the numDocuments % numThreads leftover documents, which
        // would otherwise never be visited by any worker.
        int endDoc = (threadID == numThreads - 1) ? numDocuments : firstDoc + docsPerThread;
        int numMyDocs = endDoc - firstDoc;
        if (numMyDocs <= 0) {
            // Nothing to train on (e.g. fewer documents than threads); indexing an empty ordering
            // below would throw.
            shouldRun = false;
            return;
        }
        int[] order = buildDocumentOrder(firstDoc, numMyDocs);

        double learningRate = INITIAL_LEARNING_RATE;
        long wordsAtLastRateUpdate = 0;
        double[] gradient = new double[numColumns];
        double sigmoidScale = model.sigmoidCacheSize / (model.maxExpValue - model.minExpValue);
        int[] retained = new int[INITIAL_BUFFER_SIZE];

        docID = 0;
        while (shouldRun) {
            Instance instance = instances.get(order[docID]);
            docID++;
            boolean finalDocument = false;
            if (docID == numMyDocs) {
                docID = 0;
                iteration++;
                finalDocument = iteration >= model.numIterations;
            }

            if (wordsSoFar - wordsAtLastRateUpdate > LEARNING_RATE_UPDATE_INTERVAL) {
                learningRate = decayedLearningRate();
                wordsAtLastRateUpdate = wordsSoFar;
            }

            double[] weights = model.weights;
            double[] outputWeights = model.negativeWeights;
            FeatureSequence tokens = (FeatureSequence) instance.getData();
            int length = tokens.getLength();
            if (length > retained.length) {
                // Every token may be retained, so the buffer must hold the whole document.
                retained = new int[length];
            }
            int numRetained = subsample(tokens, retained);
            if (numRetained >= minDocumentLength) {
                trainDocument(retained, numRetained, learningRate, sigmoidScale, weights, outputWeights, gradient);
            }

            // Stop only after the final document has been trained, not before.
            if (finalDocument) {
                shouldRun = false;
                return;
            }
        }
    }

    /** Builds the sequence of absolute document indices this worker visits in each pass. */
    private int[] buildDocumentOrder(int firstDoc, int numMyDocs) {
        int[] order = new int[numMyDocs];
        if (orderingStrategy == RANDOM_ORDERING) {
            for (int i = 0; i < numMyDocs; i++) {
                order[i] = firstDoc + random.nextInt(numMyDocs);
            }
            return order;
        }
        for (int i = 0; i < numMyDocs; i++) {
            order[i] = firstDoc + i;
        }
        if (orderingStrategy == SHUFFLED_ORDERING) {
            // Fisher-Yates shuffle of this worker's own slice.
            for (int i = 0; i < numMyDocs; i++) {
                int j = i + random.nextInt(numMyDocs - i);
                int tmp = order[j];
                order[j] = order[i];
                order[i] = tmp;
            }
        }
        return order;
    }

    /**
     * Returns the linearly decayed learning rate for the current progress, never below
     * {@link #MIN_LEARNING_RATE}. Progress is estimated as
     * {@code numThreads * wordsSoFar / (numIterations * totalWords)}. This presumably assumes
     * every worker processes a similar share of {@code model.totalWords}; confirm.
     */
    private double decayedLearningRate() {
        // Divide in floating point so the fraction is not truncated if the counters are integral.
        double progress = ((double) numThreads * wordsSoFar) / ((double) model.numIterations * model.totalWords);
        return Math.max(MIN_LEARNING_RATE, INITIAL_LEARNING_RATE * (1.0d - progress));
    }

    /**
     * Adds the document length to {@link #wordsSoFar} and copies into {@code buffer} the word
     * indices that survive random subsampling. Each token of word {@code w} is kept with
     * probability {@code model.retentionProbability[w]}.
     *
     * @return the number of indices written to {@code buffer}
     */
    private int subsample(FeatureSequence tokens, int[] buffer) {
        int length = tokens.getLength();
        wordsSoFar += length;
        int count = 0;
        for (int position = 0; position < length; position++) {
            int word = tokens.getIndexAtPosition(position);
            if (random.nextDouble() < model.retentionProbability[word]) {
                buffer[count++] = word;
            }
        }
        return count;
    }

    /**
     * Trains every (center, context) pair in a document. For each center position, a window
     * radius is drawn uniformly from {@code [1, model.windowSize]}. Every other retained word
     * within that radius is then trained as a context word.
     */
    private void trainDocument(int[] words, int numWords, double learningRate, double sigmoidScale,
            double[] weights, double[] outputWeights, double[] gradient) {
        for (int center = 0; center < numWords; center++) {
            int centerWord = words[center];
            int centerOffset = centerWord * stride;
            int radius = random.nextInt(model.windowSize) + 1;
            int from = Math.max(0, center - radius);
            int to = Math.min(numWords - 1, center + radius);
            for (int context = from; context <= to; context++) {
                if (context != center) {
                    int contextOffset = words[context] * stride;
                    trainPair(centerWord, centerOffset, contextOffset, learningRate, sigmoidScale,
                            weights, outputWeights, gradient);
                }
            }
        }
    }

    /**
     * Performs one positive update and {@link #numSamples} negative-sample updates for a single
     * (center, context) pair. The updates apply to the output vectors in {@code outputWeights}.
     * The gradient accumulated for the context word's vector in {@code weights} is added only
     * after all output updates, so each dot product sees that vector unchanged.
     */
    private void trainPair(int centerWord, int centerOffset, int contextOffset, double learningRate,
            double sigmoidScale, double[] weights, double[] outputWeights, double[] gradient) {
        // Positive example (label 1).
        double score = dot(outputWeights, centerOffset, weights, contextOffset);
        double coefficient = learningRate * (1.0d - sigmoid(score, sigmoidScale));
        for (int k = 0; k < numColumns; k++) {
            gradient[k] = coefficient * outputWeights[centerOffset + k];
            outputWeights[centerOffset + k] += coefficient * weights[contextOffset + k];
        }

        // Negative examples (label 0) drawn from the sampling table; draws of the center word are skipped.
        for (int sample = 0; sample < numSamples; sample++) {
            int negativeWord = model.samplingTable[random.nextInt(model.samplingTableSize)];
            if (negativeWord == centerWord) {
                continue;
            }
            int negativeOffset = negativeWord * stride;
            score = dot(outputWeights, negativeOffset, weights, contextOffset);
            coefficient = learningRate * (0.0d - sigmoid(score, sigmoidScale));
            for (int k = 0; k < numColumns; k++) {
                gradient[k] += coefficient * outputWeights[negativeOffset + k];
                outputWeights[negativeOffset + k] += coefficient * weights[contextOffset + k];
            }
        }

        numUpdates++;
        for (int k = 0; k < numColumns; k++) {
            weights[contextOffset + k] += gradient[k];
        }
    }

    /** Dot product of the {@link #numColumns}-long vectors starting at {@code aOffset} and {@code bOffset}. */
    private double dot(double[] a, int aOffset, double[] b, int bOffset) {
        double sum = 0.0d;
        for (int k = 0; k < numColumns; k++) {
            sum += a[aOffset + k] * b[bOffset + k];
        }
        return sum;
    }

    /**
     * Looks up the logistic function of {@code score} in {@code model.sigmoidCache}. Returns 0
     * below {@code minExpValue} and 1 above {@code maxExpValue}. Assumes the cache samples the
     * sigmoid over that range; confirm against WordEmbeddings.
     */
    private double sigmoid(double score, double sigmoidScale) {
        if (score < model.minExpValue) {
            return 0.0d;
        }
        if (score > model.maxExpValue) {
            return 1.0d;
        }
        int index = (int) Math.floor((score - model.minExpValue) * sigmoidScale);
        // A score exactly at maxExpValue can map to one past the end of the cache; clamp it.
        return model.sigmoidCache[Math.min(index, model.sigmoidCache.length - 1)];
    }
}
