package jsat.text.topicmodel;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.ScaledVector;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.math.FastMath;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.IntList;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;
import jsat.utils.random.RandomUtil;

/* loaded from: input_file:jsat/text/topicmodel/OnlineLDAsvi.class */
/**
 * Online topic model (the error messages in this class refer to it as LDA)
 * that is trained from mini-batches of sparse document vectors through
 * {@link #update(List, ExecutorService)}, or over a whole data set through
 * {@code model(...)}.
 *
 * <p>Per-topic state is kept in parallel lists indexed by topic
 * ({@code 0 .. K-1}); per-word state is indexed by vocabulary position
 * ({@code 0 .. W-1}).
 */
public class OnlineLDAsvi implements Parameterized {
    /** Constant added to every component of the per-document gamma update in {@code computePhi}. */
    private double alpha;
    /** Constant added to lambda entries before taking digamma, and to every randomly initialised value. */
    private double eta;
    /** Offset added to the update counter {@code t} when computing the step size. */
    private double tau0;
    /** Step-size decay: the step is {@code (tau0 + t)^(-kappa)}. */
    private double kappa;
    /** Number of passes over the data set made by {@code model(...)}. */
    private int epochs;
    /** Number of documents; {@code -1} until set. */
    private int D;
    /** Number of topics; {@code -1} until set. */
    private int K;
    /** Vocabulary size; {@code -1} until set. */
    private int W;
    /** Maximum number of documents handed to a single {@code update} call by {@code model(...)}. */
    private int miniBatchSize;
    /** Number of {@code update} calls performed since the last {@code initialize()}. */
    private int t;
    /** One vector of length {@code W} per topic; {@code null} means "not initialised yet". */
    private List<Vec> lambda;
    /** One lock per topic, taken while that topic's {@code lambda} entry is being incremented. */
    private List<Lock> lambdaLocks;
    /** Running sum of each topic's {@code lambda} vector, maintained alongside {@code lambda}. */
    private DoubleList lambdaSums;
    /** For each word, the value of {@code t} at which its beta columns were last refreshed ({@code -1} = never). */
    private int[] lastUsed;
    /** Per topic: {@code digamma(eta + lambda[k][w]) - digamma(W*eta + lambdaSums[k])}, refreshed lazily per word. */
    private List<Vec> ELogBeta;
    /** Per topic: element-wise {@code exp} of {@link #ELogBeta}, refreshed at the same time. */
    private List<Vec> ExpELogBeta;
    /** Per-thread scratch vector of length {@code K} used as gamma during {@code update}. */
    private ThreadLocal<Vec> gammaLocal;
    /** Per-thread scratch vector of length {@code K} holding {@code digamma(gamma) - digamma(sum)}. */
    private ThreadLocal<Vec> logThetaLocal;
    /** Per-thread scratch vector of length {@code K} holding {@code exp} of the log-theta vector. */
    private ThreadLocal<Vec> expLogThetaLocal;

    /**
     * Creates a model with the default hyper-parameters. The number of topics,
     * the number of documents and the vocabulary size are left at {@code -1}
     * and must be supplied (via the setters or {@code model(...)}) before the
     * first update.
     */
    public OnlineLDAsvi() {
        // Hyper-parameter defaults.
        alpha = 1.0d;
        eta = 1.0d;
        tau0 = 128.0d;
        kappa = 0.7d;
        epochs = 1;
        miniBatchSize = 256;
        // Problem dimensions: -1 marks "not specified yet".
        K = -1;
        D = -1;
        W = -1;
    }

    /**
     * Creates a model with the default hyper-parameters and the given problem
     * dimensions.
     *
     * @param i  number of topics, forwarded to {@link #setK(int)}
     * @param i2 number of documents, forwarded to {@link #setD(int)}
     * @param i3 vocabulary size, forwarded to {@link #setVocabSize(int)}
     * @throws IllegalArgumentException if any of the setters rejects its value
     */
    public OnlineLDAsvi(int i, int i2, int i3) {
        // Start from the same defaults as the no-argument constructor.
        this();
        // Validation order matters for which exception surfaces first: K, then D, then W.
        setK(i);
        setD(i2);
        setVocabSize(i3);
    }

    /**
     * Sets the number of topics. This replaces the per-thread scratch vectors
     * with ones of the new length and discards any fitted state by resetting
     * {@code lambda} to {@code null}, so the next update re-initialises.
     *
     * @param i the number of topics, at least 2
     * @throws IllegalArgumentException if {@code i < 2}
     */
    public void setK(final int i) {
        if (i <= 1) {
            throw new IllegalArgumentException("At least 2 topics must be learned");
        }
        K = i;
        gammaLocal = newTopicSizedLocal(i);
        logThetaLocal = newTopicSizedLocal(i);
        expLogThetaLocal = newTopicSizedLocal(i);
        lambda = null;
    }

    /** Builds a thread-local whose initial value is a fresh dense vector of the given length. */
    private static ThreadLocal<Vec> newTopicSizedLocal(final int length) {
        return new ThreadLocal<Vec>() {
            @Override
            protected Vec initialValue() {
                return new DenseVector(length);
            }
        };
    }

    /**
     * @return the number of topics, or {@code -1} if it has not been set
     */
    public int getK() {
        return K;
    }

    /**
     * Sets the number of documents used to scale each mini-batch update.
     *
     * @param i the number of documents, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setD(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("The number of documents must be positive, not " + i);
        }
        D = i;
    }

    /**
     * @return the number of documents, or {@code -1} if it has not been set
     */
    public int getD() {
        return D;
    }

    /**
     * Sets the vocabulary size, i.e. the length of every per-topic vector.
     *
     * @param i the vocabulary size, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setVocabSize(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("Vocabulary size must be positive, not " + i);
        }
        W = i;
    }

    /**
     * @return the vocabulary size, or {@code -1} if it has not been set
     */
    public int getVocabSize() {
        return W;
    }

    /**
     * Sets alpha, the constant added to each gamma component during the
     * per-document fixed-point iteration.
     *
     * @param d the new value; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is non-positive, infinite or NaN
     */
    public void setAlpha(double d) {
        // NaN fails the "> 0" comparison, so it is rejected together with non-positive values.
        final boolean acceptable = d > 0.0d && !Double.isInfinite(d);
        if (!acceptable) {
            throw new IllegalArgumentException("Alpha must be a positive constant, not " + d);
        }
        alpha = d;
    }

    /**
     * @return the current alpha value
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets eta, the constant added to lambda entries before digamma is taken
     * and to every randomly initialised value.
     *
     * @param d the new value; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is non-positive, infinite or NaN
     */
    public void setEta(double d) {
        // NaN fails the "> 0" comparison, so it is rejected together with non-positive values.
        final boolean acceptable = d > 0.0d && !Double.isInfinite(d);
        if (!acceptable) {
            throw new IllegalArgumentException("Eta must be a positive constant, not " + d);
        }
        eta = d;
    }

    /**
     * @return the current eta value
     */
    public double getEta() {
        return eta;
    }

    /**
     * Sets tau0, the offset added to the update counter when the step size
     * {@code (tau0 + t)^(-kappa)} is computed.
     *
     * @param d the new value; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is non-positive, infinite or NaN
     */
    public void setTau0(double d) {
        if (d <= 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            // The message previously named "Eta" (copy-paste from setEta), which misdirected callers.
            throw new IllegalArgumentException("Tau0 must be a positive constant, not " + d);
        }
        this.tau0 = d;
    }

    /**
     * Sets how many passes over the data set {@code model(...)} performs.
     *
     * @param i the number of epochs, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive; a
     *         non-positive value would make {@code model(...)} return without
     *         performing any update
     */
    public void setEpochs(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of epochs must be positive, not " + i);
        }
        this.epochs = i;
    }

    /**
     * @return the number of passes over the data set made by {@code model(...)}
     */
    public int getEpochs() {
        return epochs;
    }

    /**
     * Sets kappa, the decay exponent of the step size
     * {@code (tau0 + t)^(-kappa)}.
     *
     * @param d the new value, in the closed interval {@code [0.5, 1]}
     * @throws IllegalArgumentException if {@code d} is outside that interval or NaN
     */
    public void setKappa(double d) {
        if (d < 0.5d || d > 1.0d || Double.isNaN(d)) {
            // Parameter name was misspelled "Kapp" in the message.
            throw new IllegalArgumentException("Kappa must be in [0.5, 1], not " + d);
        }
        this.kappa = d;
    }

    /**
     * @return the step-size decay exponent
     */
    public double getKappa() {
        return kappa;
    }

    /**
     * Sets the maximum number of documents that {@code model(...)} passes to a
     * single update.
     *
     * @param i the batch size, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setMiniBatchSize(int i) {
        if (i <= 0) {
            throw new IllegalArgumentException("the batch size must be a positive constant, not " + i);
        }
        miniBatchSize = i;
    }

    /**
     * Returns a view of one topic's lambda vector rescaled by the reciprocal
     * of its sum, so that its entries add up to one.
     *
     * @param i the topic index
     * @return a scaled wrapper around the topic's lambda vector (not a copy)
     */
    public Vec getTopicVec(int i) {
        final Vec topicLambda = lambda.get(i);
        final double normalizer = 1.0d / topicLambda.sum();
        return new ScaledVector(normalizer, topicLambda);
    }

    /**
     * Writes {@code digamma(vec[k]) - digamma(d)} into {@code vec2[k]} for
     * every position of {@code vec}.
     *
     * @param vec  the input values
     * @param d    the value whose digamma is subtracted from every entry
     *             (callers pass the sum of {@code vec})
     * @param vec2 the destination vector, at least as long as {@code vec}
     */
    private void expandPsiMinusPsiSum(Vec vec, double d, Vec vec2) {
        final double psiOfTotal = FastMath.digamma(d);
        final int size = vec.length();
        for (int k = 0; k < size; k++) {
            final double psiOfEntry = FastMath.digamma(vec.get(k));
            vec2.set(k, psiOfEntry - psiOfTotal);
        }
    }

    /**
     * Maps a uniform draw to an exponentially distributed value by inverting
     * the CDF: {@code -d * ln(1 - d2)}.
     *
     * @param d  the scale (mean) of the exponential distribution
     * @param d2 a uniform value in {@code [0, 1)}
     * @return the transformed sample
     */
    private static double sampleExpoDist(double d, double d2) {
        final double survival = 1.0d - d2;
        return (-d) * FastMath.log(survival);
    }

    /**
     * Performs one mini-batch update on the calling thread's executor
     * substitute ({@link FakeExecutor}).
     *
     * @param list the documents of the mini-batch
     */
    public void update(List<Vec> list) {
        final ExecutorService inlineExecutor = new FakeExecutor();
        update(list, inlineExecutor);
    }

    /**
     * Performs one mini-batch update of the topic parameters.
     *
     * <p>Steps: lazily initialise the model, refresh the beta columns of the
     * words that occur in the batch, decay every topic's lambda (and its
     * running sum) by {@code 1 - rho} with {@code rho = (tau0 + t)^(-kappa)},
     * then split the batch into {@code SystemInfo.LogicalCores} contiguous
     * ranges and let each worker fold its documents into lambda.
     *
     * <p>If the calling thread is interrupted while waiting for the workers,
     * the interruption is logged, the thread's interrupt flag is restored and
     * the method returns; workers may still be running at that point.
     *
     * @param list            the documents of the mini-batch
     * @param executorService where the worker tasks are submitted; when
     *                        {@code null} a {@link FakeExecutor} is used
     */
    public void update(final List<Vec> list, ExecutorService executorService) {
        if (executorService == null) {
            executorService = new FakeExecutor();
        }
        if (this.lambda == null) {
            initialize();
        }
        updateBetas(list, executorService);

        // Step size uses the pre-increment update counter.
        final double rho = Math.pow(this.tau0 + this.t, -this.kappa);
        this.t++;

        final double decay = 1.0d - rho;
        for (int k = 0; k < this.K; k++) {
            this.lambda.get(k).mutableMultiply(decay);
            this.lambdaSums.set(k, this.lambdaSums.getD(k) * decay);
        }

        final int workers = SystemInfo.LogicalCores;
        final int batchSize = list.size();
        final CountDownLatch done = new CountDownLatch(workers);
        for (int w = 0; w < workers; w++) {
            final int workerId = w;
            executorService.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        final Random random = RandomUtil.getRandom();
                        final int from = ParallelUtils.getStartBlock(batchSize, workerId, workers);
                        final int to = ParallelUtils.getEndBlock(batchSize, workerId, workers);
                        for (int doc = from; doc < to; doc++) {
                            foldDocumentIntoLambda(list.get(doc), rho, batchSize, random);
                        }
                    } finally {
                        // Always release the latch, otherwise a failing worker
                        // would leave the caller blocked in await() forever.
                        done.countDown();
                    }
                }
            });
        }

        try {
            done.await();
        } catch (InterruptedException e) {
            // Preserve the interruption for callers further up the stack.
            Thread.currentThread().interrupt();
            Logger.getLogger(OnlineLDAsvi.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
    }

    /**
     * Runs the per-document inference for one document and adds its
     * contribution to every topic's lambda vector and running sum.
     * Documents without non-zero entries are skipped.
     */
    private void foldDocumentIntoLambda(Vec doc, double rho, int batchSize, Random random) {
        final int nnz = doc.nnz();
        if (nnz == 0) {
            return;
        }
        final Vec logTheta = this.logThetaLocal.get();
        final Vec expLogTheta = this.expLogThetaLocal.get();
        final Vec gamma = this.gammaLocal.get();
        prepareGammaTheta(gamma, logTheta, expLogTheta, random);

        final int[] indexMap = new int[nnz];
        final double[] phiCols = new double[nnz];
        computePhi(doc, indexMap, phiCols, this.K, gamma, logTheta, expLogTheta);

        // Visit topics in random order and skip any topic whose lock is
        // currently held, coming back to it later, so workers rarely wait.
        final IntList pending = new IntList(this.K);
        ListUtils.addRange(pending, 0, this.K, 1);
        Collections.shuffle(pending, random);
        int pos = 0;
        while (!pending.isEmpty()) {
            final int topic = pending.getI(pos);
            final Lock topicLock = this.lambdaLocks.get(topic);
            if (topicLock.tryLock()) {
                try {
                    final double scale = ((expLogTheta.get(topic) * rho) * this.D) / batchSize;
                    final Vec topicLambda = this.lambda.get(topic);
                    final Vec topicExpELogBeta = this.ExpELogBeta.get(topic);
                    double topicSum = this.lambdaSums.getD(topic);
                    for (int n = 0; n < nnz; n++) {
                        final int word = indexMap[n];
                        final double delta = scale * phiCols[n] * topicExpELogBeta.get(word);
                        topicLambda.increment(word, delta);
                        topicSum += delta;
                    }
                    this.lambdaSums.set(topic, topicSum);
                } finally {
                    // Release even if the increment throws, so other workers cannot spin forever.
                    topicLock.unlock();
                }
                pending.remove(pos);
            }
            if (!pending.isEmpty()) {
                pos = (pos + 1) % pending.size();
            }
        }
    }

    /**
     * Fits the model to a data set without external parallelism by using a
     * {@link FakeExecutor}.
     *
     * @param dataSet the documents to learn from
     * @param i       the number of topics
     */
    public void model(DataSet dataSet, int i) {
        final ExecutorService inlineExecutor = new FakeExecutor();
        model(dataSet, i, inlineExecutor);
    }

    /**
     * Fits the model to a data set. The number of topics, the number of
     * documents and the vocabulary size are taken from the arguments, which
     * also discards any previously fitted state. Each epoch shuffles the
     * documents and feeds them to {@link #update(List, ExecutorService)} in
     * consecutive chunks of at most {@code miniBatchSize} documents.
     *
     * @param dataSet         the documents to learn from
     * @param i               the number of topics
     * @param executorService where parallel work is submitted; when
     *                        {@code null} a {@link FakeExecutor} is used
     * @throws IllegalArgumentException if the topic count or the data set's
     *         dimensions are rejected by the corresponding setters
     */
    public void model(DataSet dataSet, int i, ExecutorService executorService) {
        if (executorService == null) {
            executorService = new FakeExecutor();
        }
        setK(i);
        setD(dataSet.getSampleSize());
        setVocabSize(dataSet.getNumNumericalVars());
        List<Vec> docs = dataSet.getDataVectors();
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(docs);
            // Bounded walk over the documents: the previous while(true) form
            // had no exit and never returned once every batch was processed.
            int start = 0;
            while (start < this.D) {
                // Clamp using the remaining count so start + miniBatchSize cannot overflow.
                final int end = start + Math.min(this.miniBatchSize, this.D - start);
                update(docs.subList(start, end), executorService);
                start = end;
            }
        }
    }

    /**
     * Infers the topic proportions of a single document with the current
     * model. Gamma is initialised randomly, refined by the same fixed-point
     * iteration that is used during training, and then normalised to sum to one.
     *
     * @param vec the document as a vector over the vocabulary
     * @return a new vector of length {@code K} holding the normalised gamma
     */
    public Vec getTopics(Vec vec) {
        final DenseVector gamma = new DenseVector(K);
        final DenseVector logTheta = new DenseVector(K);
        final DenseVector expLogTheta = new DenseVector(K);

        // Same random initialisation and digamma/exp expansion as in training.
        prepareGammaTheta(gamma, logTheta, expLogTheta, RandomUtil.getRandom());

        final int nnz = vec.nnz();
        computePhi(vec, new int[nnz], new double[nnz], K, gamma, logTheta, expLogTheta);

        gamma.mutableDivide(gamma.sum());
        return gamma;
    }

    /**
     * Refreshes the {@code ELogBeta} / {@code ExpELogBeta} columns of every
     * word that occurs in the given documents and has not been refreshed yet
     * at the current update counter. The documents are split into
     * {@code SystemInfo.LogicalCores} parts that are processed as separate tasks.
     *
     * <p>Two tasks may refresh the same word concurrently; both compute the
     * column from the same lambda values.
     *
     * <p>If the calling thread is interrupted while waiting, the interruption
     * is logged and the thread's interrupt flag is restored.
     *
     * @param list            the documents of the current mini-batch
     * @param executorService where the tasks are submitted
     */
    private void updateBetas(List<Vec> list, ExecutorService executorService) {
        final double[] digammaOfTopicSum = new double[this.K];
        for (int k = 0; k < this.K; k++) {
            digammaOfTopicSum[k] = FastMath.digamma((this.W * this.eta) + this.lambdaSums.getD(k));
        }
        // Snapshot of the update counter that marks a column as fresh for this call.
        final int stamp = this.t;

        final List<List<Vec>> parts = ListUtils.splitList(list, SystemInfo.LogicalCores);
        final CountDownLatch done = new CountDownLatch(parts.size());
        for (final List<Vec> part : parts) {
            executorService.submit(new Runnable() {
                @Override
                public void run() {
                    try {
                        for (Vec doc : part) {
                            for (IndexValue entry : doc) {
                                final int word = entry.getIndex();
                                if (OnlineLDAsvi.this.lastUsed[word] == stamp) {
                                    continue; // already refreshed during this update
                                }
                                for (int k = 0; k < OnlineLDAsvi.this.K; k++) {
                                    final double logBeta = FastMath.digamma(
                                            OnlineLDAsvi.this.eta + OnlineLDAsvi.this.lambda.get(k).get(word))
                                            - digammaOfTopicSum[k];
                                    OnlineLDAsvi.this.ELogBeta.get(k).set(word, logBeta);
                                    OnlineLDAsvi.this.ExpELogBeta.get(k).set(word, FastMath.exp(logBeta));
                                }
                                OnlineLDAsvi.this.lastUsed[word] = stamp;
                            }
                        }
                    } finally {
                        // Always release the latch, otherwise a failing task
                        // would leave the caller blocked in await() forever.
                        done.countDown();
                    }
                }
            });
        }

        try {
            done.await();
        } catch (InterruptedException e) {
            // Preserve the interruption for callers further up the stack.
            Thread.currentThread().interrupt();
            Logger.getLogger(OnlineLDAsvi.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
        }
    }

    /**
     * Randomly initialises gamma and derives the two vectors that depend on it.
     * Every gamma entry is an exponential draw with scale
     * {@code W*K / (D*100)} plus {@code eta}; {@code vec2} then receives
     * {@code digamma(gamma[k]) - digamma(sum(gamma))} and {@code vec3} its
     * element-wise exponential.
     *
     * @param vec    gamma, overwritten
     * @param vec2   log-theta, overwritten
     * @param vec3   exp(log-theta), overwritten
     * @param random the source of uniform draws
     */
    public void prepareGammaTheta(Vec vec, Vec vec2, Vec vec3, Random random) {
        // Multiply in double: as an int product W * K wraps around for a large
        // vocabulary combined with many topics, which would make the scale wrong.
        final double scale = ((double) this.W * this.K) / (this.D * 100.0d);
        for (int k = 0; k < vec.length(); k++) {
            vec.set(k, sampleExpoDist(scale, random.nextDouble()) + this.eta);
        }
        expandPsiMinusPsiSum(vec, vec.sum(), vec2);
        for (int k = 0; k < vec2.length(); k++) {
            vec3.set(k, FastMath.exp(vec2.get(k)));
        }
    }

    /**
     * Runs the per-document fixed-point iteration. On return {@code iArr}
     * holds the word index of each non-zero entry of the document,
     * {@code dArr} the matching word count divided by its normaliser, and the
     * three vectors hold the refined gamma, log-theta and exp(log-theta).
     *
     * <p>At most 100 rounds are performed; the loop stops early once the
     * summed absolute change of gamma drops below {@code 0.001 * i}.
     *
     * @param vec  the document
     * @param iArr output, word index per non-zero entry (length {@code vec.nnz()})
     * @param dArr output, scaled count per non-zero entry (length {@code vec.nnz()})
     * @param i    the number of topics to update
     * @param vec2 gamma, updated in place
     * @param vec3 log-theta, updated in place
     * @param vec4 exp(log-theta), read for the first pass and updated in place
     */
    public void computePhi(Vec vec, int[] iArr, double[] dArr, int i, Vec vec2, Vec vec3, Vec vec4) {
        final Vec gamma = vec2;
        final Vec logTheta = vec3;
        final Vec expLogTheta = vec4;

        // Sparse view over the two output arrays; later writes to dArr are
        // what the dot products below are meant to pick up.
        final SparseVector phi = new SparseVector(iArr, dArr, this.W, vec.nnz());

        int slot = 0;
        for (IndexValue word : vec) {
            final int wordIndex = word.getIndex();
            final double norm = phiNormalizer(expLogTheta, wordIndex);
            iArr[slot] = wordIndex;
            dArr[slot] = word.getValue() / (norm + 1.0E-15d);
            slot++;
        }

        for (int round = 0; round < 100; round++) {
            double totalChange = 0.0d;
            double gammaSum = 0.0d;
            for (int k = 0; k < i; k++) {
                final double previous = gamma.get(k);
                final double updated = this.alpha + (expLogTheta.get(k) * phi.dot(this.ExpELogBeta.get(k)));
                gamma.set(k, updated);
                totalChange += Math.abs(updated - previous);
                gammaSum += updated;
            }

            expandPsiMinusPsiSum(gamma, gammaSum, logTheta);
            for (int k = 0; k < logTheta.length(); k++) {
                expLogTheta.set(k, FastMath.exp(logTheta.get(k)));
            }

            // Re-scale the counts with the new exp(log-theta).
            int pos = 0;
            for (IndexValue word : vec) {
                final double norm = phiNormalizer(expLogTheta, word.getIndex());
                dArr[pos] = word.getValue() / (norm + 1.0E-15d);
                pos++;
            }

            if (totalChange < 0.001d * i) {
                return;
            }
        }
    }

    /** Sums {@code expLogTheta[k] * ExpELogBeta[k][wordIndex]} over all positions of {@code expLogTheta}. */
    private double phiNormalizer(Vec expLogTheta, int wordIndex) {
        double total = 0.0d;
        for (int k = 0; k < expLogTheta.length(); k++) {
            total += expLogTheta.get(k) * this.ExpELogBeta.get(k).get(wordIndex);
        }
        return total;
    }

    /**
     * Allocates and randomly initialises all per-topic and per-word state and
     * resets the update counter to zero. Every lambda entry is an exponential
     * draw with scale {@code K*W / (D*100)} plus {@code eta}; the beta vectors
     * start at zero and every word is marked as never refreshed.
     *
     * @throws FailedToFitException if the number of topics, the number of
     *         documents or the vocabulary size has not been set
     */
    private void initialize() {
        if (this.K < 1) {
            throw new FailedToFitException("Topic number for LDA has not yet been specified");
        }
        if (this.D < 1) {
            throw new FailedToFitException("Expected number of documents has not yet been specified");
        }
        if (this.W < 1) {
            throw new FailedToFitException("Topic vocabulary size has not yet been specified");
        }
        this.t = 0;
        this.lambda = new ArrayList<Vec>(this.K);
        this.lambdaLocks = new ArrayList<Lock>(this.K);
        this.lambdaSums = new DoubleList(this.K);
        this.ELogBeta = new ArrayList<Vec>(this.K);
        this.ExpELogBeta = new ArrayList<Vec>(this.K);
        this.lastUsed = new int[this.W];
        Arrays.fill(this.lastUsed, -1);

        // Multiply in double: as an int product K * W wraps around for a large
        // vocabulary combined with many topics, which would make the scale wrong.
        final double scale = ((double) this.K * this.W) / (this.D * 100.0d);
        final Random random = RandomUtil.getRandom();
        for (int k = 0; k < this.K; k++) {
            final DenseVector values = new DenseVector(this.W);
            double sum = 0.0d;
            for (int w = 0; w < this.W; w++) {
                final double sample = sampleExpoDist(scale, random.nextDouble()) + this.eta;
                values.set(w, sample);
                sum += sample;
            }
            // Wrapped so that the per-update decay of a whole topic is handled by the wrapper.
            this.lambda.add(new ScaledVector(values));
            this.lambdaLocks.add(new ReentrantLock());
            this.ELogBeta.add(new DenseVector(this.W));
            this.ExpELogBeta.add(new DenseVector(this.W));
            this.lambdaSums.add(sum);
        }
    }
}
