package cc.mallet.topics;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.IDSorter;
import cc.mallet.types.InstanceList;
import cc.mallet.util.CommandOption;
import cc.mallet.util.FeatureCountTool;
import cc.mallet.util.Randoms;
import java.io.File;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;

/* loaded from: input_file:cc/mallet/topics/NonNegativeMatrixFactorization.class */
public class NonNegativeMatrixFactorization {
    // Command-line options. All of them are constructed in this class's static initializer.
    // "input": filename of the training instances.
    static CommandOption.String inputFile;
    // "output-words": filename for the per-word (feature) factor weights.
    static CommandOption.String outputWordsFile;
    // "output-docs": filename for the per-document (instance) factor weights.
    static CommandOption.String outputDocsFile;
    // "num-dimensions": number of factors to fit.
    static CommandOption.Integer numDimensions;
    // "init-cluster-size": number of random instances used to seed each factor; main() skips seeding unless > 0.
    static CommandOption.Integer clusterSize;
    // "use-idf": whether feature values are multiplied by IDF weights.
    static CommandOption.Boolean useIDFOption;
    // "num-iters": upper bound on the number of update passes run by main().
    static CommandOption.Integer numIterationsOption;
    // The data being factorized; each instance's data is cast to FeatureVector when read.
    InstanceList instances;
    // Number of factors (second dimension of both weight matrices).
    int numFactors;
    // Size of the instance list's data alphabet.
    int numFeatures;
    // Number of instances in the list.
    int numInstances;
    // Not read or written by any code in this class.
    int numIterations;
    // True when feature values are scaled by featureWeights.
    boolean idfWeighting;
    // Per-feature IDF weights; null until calculateIDFWeights() has run.
    double[] featureWeights;
    // Feature-by-factor weights, allocated as [numFeatures][numFactors].
    double[][] featureFactorWeights;
    // Instance-by-factor weights, allocated as [numInstances][numFactors].
    double[][] instanceFactorWeights;
    // Per-factor sum of featureFactorWeights over all features.
    double[] featureSums;
    // Per-factor sum of instanceFactorWeights over all instances.
    double[] instanceSums;
    // Random source used for the initial feature weights and for sampling in initialize().
    Randoms random;
    // Nine strings (a space plus eight block characters) that getBar() indexes with a value from 0 to 8.
    public static final String[] BARS;
    // Explicit copy of the compiler's assertion flag; the methods test it before throwing AssertionError.
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * Creates a factorization that uses a newly constructed {@link Randoms} as its random source.
     * Delegates to the four-argument constructor.
     *
     * @param instanceList the instances to factorize
     * @param i the number of factors
     * @param z whether feature values are scaled by IDF weights
     */
    public NonNegativeMatrixFactorization(InstanceList instanceList, int i, boolean z) {
        this(instanceList, i, z, new Randoms());
    }

    /**
     * Creates a factorization over the given instances and fills both weight matrices with
     * their starting values.
     *
     * <p>Every feature-factor weight starts at {@code 0.001 * uniform / numFeatures}, drawn from
     * {@code randoms} feature by feature and, within a feature, factor by factor. Every
     * instance-factor weight starts at {@code 1 / i}. The per-factor sums are accumulated while
     * the matrices are filled.
     *
     * @param instanceList the instances to factorize
     * @param i the number of factors
     * @param z whether feature values are scaled by IDF weights; when true the weights are
     *          computed here, before any random number is drawn
     * @param randoms the random source, also kept for later use by {@link #initialize(int)}
     */
    public NonNegativeMatrixFactorization(InstanceList instanceList, int i, boolean z, Randoms randoms) {
        this.featureWeights = null;
        this.instances = instanceList;
        this.numFactors = i;
        this.idfWeighting = z;
        this.random = randoms;

        this.numFeatures = instanceList.getDataAlphabet().size();
        this.numInstances = instanceList.size();

        this.featureFactorWeights = new double[this.numFeatures][i];
        this.instanceFactorWeights = new double[this.numInstances][i];
        this.featureSums = new double[i];
        this.instanceSums = new double[i];

        if (z) {
            calculateIDFWeights();
        }

        // Small random feature weights; the draw order (feature-major) fixes the random sequence.
        for (double[] featureRow : this.featureFactorWeights) {
            for (int factor = 0; factor < i; factor++) {
                double weight = (0.001d * randoms.nextUniform()) / this.numFeatures;
                featureRow[factor] = weight;
                this.featureSums[factor] += weight;
            }
        }

        // Uniform instance weights.
        final double uniformShare = 1.0d / i;
        for (double[] instanceRow : this.instanceFactorWeights) {
            Arrays.fill(instanceRow, uniformShare);
            for (int factor = 0; factor < i; factor++) {
                this.instanceSums[factor] += instanceRow[factor];
            }
        }
    }

    /**
     * Turns on IDF weighting and computes one weight per feature as
     * {@code log(numInstances / documentFrequency)}.
     *
     * <p>Features whose document frequency is zero keep the array default weight of 0.
     * Prints a progress message to standard output before counting.
     */
    public void calculateIDFWeights() {
        this.idfWeighting = true;
        System.out.println("Counting word features");

        FeatureCountTool counter = new FeatureCountTool(this.instances);
        counter.count();
        int[] documentFrequencies = counter.getDocumentFrequencies();

        this.featureWeights = new double[this.numFeatures];
        for (int feature = 0; feature < this.numFeatures; feature++) {
            int documentFrequency = documentFrequencies[feature];
            if (documentFrequency > 0) {
                // Divide as doubles: int / int would truncate the ratio before the logarithm,
                // e.g. giving weight 0 to every feature found in more than half the instances.
                this.featureWeights[feature] = Math.log((double) this.numInstances / documentFrequency);
            }
        }
    }

    /**
     * Seeds every factor from randomly chosen instances.
     *
     * <p>For each factor, {@code i} instances are drawn with {@code random.nextInt(numInstances)}.
     * Each feature value of a drawn instance (IDF-scaled when IDF weighting is on) is divided
     * by {@code i} and added to that feature's weight for the factor and to the factor's
     * feature sum.
     *
     * @param i the number of instances sampled per factor
     */
    public void initialize(int i) {
        for (int factor = 0; factor < this.numFactors; factor++) {
            for (int draw = 0; draw < i; draw++) {
                int sampled = this.random.nextInt(this.numInstances);
                FeatureVector document = (FeatureVector) this.instances.get(sampled).getData();

                for (int location = 0; location < document.numLocations(); location++) {
                    int feature = document.indexAtLocation(location);
                    double value = document.valueAtLocation(location);
                    if (this.idfWeighting) {
                        value *= this.featureWeights[feature];
                    }

                    // The same share goes into the weight matrix and into the running sum.
                    double share = value / i;
                    this.featureFactorWeights[feature][factor] += share;
                    this.featureSums[factor] += share;
                }
            }
        }
    }

    /**
     * Picks the entry of {@link #BARS} that represents a value's position inside a range.
     *
     * <p>The value is first limited to at most {@code d3} and then raised to at least
     * {@code d2}. Its position in the range is scaled to 0..8 and rounded to give the index.
     *
     * @param d the value to display
     * @param d2 the lower end of the range
     * @param d3 the upper end of the range
     * @return the selected element of {@code BARS}
     */
    public static String getBar(double d, double d2, double d3) {
        double clamped = (d > d3) ? d3 : d;
        if (clamped < d2) {
            clamped = d2;
        }
        double eighths = (8.0d * (clamped - d2)) / (d3 - d2);
        return BARS[(int) Math.round(eighths)];
    }

    /**
     * Renders an array of values as a string with one {@link #getBar} result per value,
     * in array order.
     *
     * @param dArr the values to display
     * @param d the lower end of the range passed to {@code getBar}
     * @param d2 the upper end of the range passed to {@code getBar}
     * @return the concatenated bars; empty for an empty array
     */
    public static String getBars(double[] dArr, double d, double d2) {
        StringBuilder bars = new StringBuilder(dArr.length);
        for (int index = 0; index < dArr.length; index++) {
            bars.append(getBar(dArr[index], d, d2));
        }
        return bars.toString();
    }

    /**
     * Renders an array of values as bars scaled between 0 and the largest value in the array.
     *
     * <p>The lower end of the range is always 0, not the smallest value, so only the maximum
     * has to be found.
     *
     * @param dArr the values to display
     * @return one bar per value, in array order
     */
    public static String getBars(double[] dArr) {
        double largest = Double.NEGATIVE_INFINITY;
        for (double value : dArr) {
            if (value > largest) {
                largest = value;
            }
        }
        return getBars(dArr, 0.0d, largest);
    }

    /**
     * Measures how far the current factorization is from the data.
     *
     * <p>For every entry stored in an instance's feature vector, with value {@code v}
     * (IDF-scaled when IDF weighting is on) and reconstruction {@code r} (the dot product of
     * the instance's and the feature's factor weights), the term
     * {@code v * log(v / r) - v + r} is added. Entries whose reconstruction is exactly zero
     * are skipped. Entries that are not stored in the feature vector are not visited.
     *
     * @return the sum of the terms over all stored entries of all instances
     */
    public double getDivergence() {
        double total = 0.0d;
        for (int instance = 0; instance < this.numInstances; instance++) {
            FeatureVector document = (FeatureVector) this.instances.get(instance).getData();
            double[] instanceWeights = this.instanceFactorWeights[instance];

            for (int location = 0; location < document.numLocations(); location++) {
                int feature = document.indexAtLocation(location);
                double value = document.valueAtLocation(location);
                if (this.idfWeighting) {
                    value *= this.featureWeights[feature];
                }

                double[] featureWeightsRow = this.featureFactorWeights[feature];
                double reconstruction = 0.0d;
                for (int factor = 0; factor < this.numFactors; factor++) {
                    reconstruction += instanceWeights[factor] * featureWeightsRow[factor];
                }

                if (reconstruction != 0.0d) {
                    // A zero value would evaluate 0 * log(0) = 0 * -Infinity = NaN and poison the
                    // whole sum; the limit of v * log(v / r) as v goes to 0 is 0.
                    double logTerm = (value == 0.0d) ? 0.0d : value * Math.log(value / reconstruction);
                    total += (logTerm - value) + reconstruction;
                }
            }
        }
        return total;
    }

    /**
     * Runs one pass of multiplicative updates over both weight matrices.
     *
     * <p>The four steps run in this order, and each depends on the one before it:
     * instance weights are updated using the current feature sums, instance sums are
     * recomputed, feature weights are updated using the new instance weights and sums, and
     * feature sums are recomputed.
     */
    public void updateWeights() {
        updateInstanceFactorWeights();
        recomputeFactorSums(this.instanceFactorWeights, this.numInstances, this.instanceSums);
        updateFeatureFactorWeights();
        recomputeFactorSums(this.featureFactorWeights, this.numFeatures, this.featureSums);
    }

    /** Rescales each instance's factor weights in place; instances with no positive total value are zeroed. */
    private void updateInstanceFactorWeights() {
        for (int instance = 0; instance < this.numInstances; instance++) {
            FeatureVector document = (FeatureVector) this.instances.get(instance).getData();
            double[] instanceWeights = this.instanceFactorWeights[instance];
            double[] numerators = new double[this.numFactors];
            double totalValue = 0.0d;

            for (int location = 0; location < document.numLocations(); location++) {
                int feature = document.indexAtLocation(location);
                double value = document.valueAtLocation(location);
                if (this.idfWeighting) {
                    value *= this.featureWeights[feature];
                }
                totalValue += value;

                double[] featureWeightsRow = this.featureFactorWeights[feature];
                double reconstruction = 0.0d;
                for (int factor = 0; factor < this.numFactors; factor++) {
                    reconstruction += instanceWeights[factor] * featureWeightsRow[factor];
                }

                // A zero reconstruction cannot be divided by; such entries contribute nothing.
                if (reconstruction != 0.0d) {
                    double ratio = value / reconstruction;
                    for (int factor = 0; factor < this.numFactors; factor++) {
                        numerators[factor] += featureWeightsRow[factor] * ratio;
                    }
                }
            }

            if (totalValue > 0.0d) {
                for (int factor = 0; factor < this.numFactors; factor++) {
                    instanceWeights[factor] *= numerators[factor] / this.featureSums[factor];
                    if (!$assertionsDisabled && Double.isNaN(instanceWeights[factor])) {
                        throw new AssertionError();
                    }
                }
            } else {
                for (int factor = 0; factor < this.numFactors; factor++) {
                    instanceWeights[factor] = 0.0d;
                }
            }
        }
    }

    /** Rescales every feature's factor weights in place, using the current instance weights and instance sums. */
    private void updateFeatureFactorWeights() {
        double[][] numerators = new double[this.numFeatures][this.numFactors];

        for (int instance = 0; instance < this.numInstances; instance++) {
            FeatureVector document = (FeatureVector) this.instances.get(instance).getData();
            double[] instanceWeights = this.instanceFactorWeights[instance];

            for (int location = 0; location < document.numLocations(); location++) {
                int feature = document.indexAtLocation(location);
                double value = document.valueAtLocation(location);
                if (this.idfWeighting) {
                    value *= this.featureWeights[feature];
                }
                if (value == 0.0d) {
                    continue;
                }

                double[] featureWeightsRow = this.featureFactorWeights[feature];
                double reconstruction = 0.0d;
                for (int factor = 0; factor < this.numFactors; factor++) {
                    if (!$assertionsDisabled && instanceWeights[factor] < 0.0d) {
                        throw new AssertionError();
                    }
                    if (!$assertionsDisabled && featureWeightsRow[factor] < 0.0d) {
                        throw new AssertionError();
                    }
                    reconstruction += instanceWeights[factor] * featureWeightsRow[factor];
                }

                // Same guard as in the instance pass: dividing by a zero reconstruction yields
                // Infinity, and Infinity times a zero weight is NaN, which would then spread to
                // every weight of the factor.
                if (reconstruction == 0.0d) {
                    continue;
                }

                double ratio = value / reconstruction;
                double[] featureNumerators = numerators[feature];
                for (int factor = 0; factor < this.numFactors; factor++) {
                    if (!$assertionsDisabled && Double.isNaN(instanceWeights[factor])) {
                        throw new AssertionError();
                    }
                    if (!$assertionsDisabled && Double.isNaN(ratio)) {
                        throw new AssertionError(value + " / " + reconstruction);
                    }
                    featureNumerators[factor] += instanceWeights[factor] * ratio;
                }
            }
        }

        for (int feature = 0; feature < this.numFeatures; feature++) {
            double[] featureWeightsRow = this.featureFactorWeights[feature];
            for (int factor = 0; factor < this.numFactors; factor++) {
                if (!$assertionsDisabled && Double.isNaN(numerators[feature][factor])) {
                    throw new AssertionError();
                }
                if (!$assertionsDisabled && Double.isNaN(this.instanceSums[factor])) {
                    throw new AssertionError(this.instanceSums[factor]);
                }
                featureWeightsRow[factor] *= numerators[feature][factor] / this.instanceSums[factor];
                if (!$assertionsDisabled && Double.isNaN(featureWeightsRow[factor])) {
                    throw new AssertionError();
                }
            }
        }
    }

    /** Overwrites {@code sums} with the per-factor totals of the first {@code rowCount} rows of {@code weights}. */
    private void recomputeFactorSums(double[][] weights, int rowCount, double[] sums) {
        Arrays.fill(sums, 0.0d);
        for (int row = 0; row < rowCount; row++) {
            for (int factor = 0; factor < this.numFactors; factor++) {
                sums[factor] += weights[row][factor];
            }
        }
    }

    /**
     * Prints, for every factor, one line holding the factor index, a tab, and the first
     * {@code i} features of the sorted ranking, each followed by a space.
     *
     * <p>Features are ranked by sorting {@link IDSorter} entries built from their weight for
     * the factor (presumably highest weight first; confirm against
     * {@code IDSorter.compareTo}). If {@code i} exceeds the number of features, all features
     * are printed.
     *
     * @param i the maximum number of features to print per factor
     */
    public void printFactorFeatures(int i) {
        // Never read past the end of the ranking when the vocabulary is smaller than requested.
        int shown = Math.min(i, this.numFeatures);

        IDSorter[] ranked = new IDSorter[this.numFeatures];
        StringBuilder report = new StringBuilder();
        for (int factor = 0; factor < this.numFactors; factor++) {
            for (int feature = 0; feature < this.numFeatures; feature++) {
                ranked[feature] = new IDSorter(feature, this.featureFactorWeights[feature][factor]);
            }
            Arrays.sort(ranked);

            report.append(factor).append('\t');
            for (int rank = 0; rank < shown; rank++) {
                report.append(this.instances.getDataAlphabet().lookupObject(ranked[rank].getID())).append(' ');
            }
            report.append('\n');
        }
        System.out.println(report);
    }

    /**
     * Writes one line per feature: the feature's alphabet entry, then each of its factor
     * weights preceded by a tab and formatted with {@code %f}.
     *
     * @param printWriter the destination; it is neither flushed nor closed here
     * @throws IOException declared by the signature; nothing in this body throws it directly
     */
    public void writeFeatureFactors(PrintWriter printWriter) throws IOException {
        for (int feature = 0; feature < this.numFeatures; feature++) {
            double[] weights = this.featureFactorWeights[feature];
            printWriter.print(this.instances.getDataAlphabet().lookupObject(feature));
            int factor = 0;
            while (factor < this.numFactors) {
                printWriter.printf("\t%f", weights[factor]);
                factor++;
            }
            printWriter.println();
        }
    }

    /**
     * Writes one line per instance: the instance's name, then each of its factor weights
     * preceded by a tab and formatted with {@code %f}.
     *
     * @param printWriter the destination; it is neither flushed nor closed here
     * @throws IOException declared by the signature; nothing in this body throws it directly
     */
    public void writeInstanceFactors(PrintWriter printWriter) throws IOException {
        for (int instance = 0; instance < this.numInstances; instance++) {
            double[] weights = this.instanceFactorWeights[instance];
            printWriter.print(this.instances.get(instance).getName());
            int factor = 0;
            while (factor < this.numFactors) {
                printWriter.printf("\t%f", weights[factor]);
                factor++;
            }
            printWriter.println();
        }
    }

    /**
     * Command-line entry point: loads the instances, fits the factorization and writes the
     * resulting weights.
     *
     * <p>Every 10th iteration the divergence is printed next to bar charts of the feature and
     * instance sums. Training stops early once the divergence is more than 0.9999 of the value
     * seen at the previous check. Every 100th iteration the first 15 ranked features of each
     * factor are printed.
     *
     * @param strArr the command-line arguments, parsed by {@link CommandOption}
     * @throws Exception if loading the instances or writing an output file fails
     */
    public static void main(String[] strArr) throws Exception {
        CommandOption.setSummary(NonNegativeMatrixFactorization.class, "Train non-negative matrix factorization.");
        CommandOption.process(NonNegativeMatrixFactorization.class, strArr);

        InstanceList training = InstanceList.load(new File(inputFile.value));
        NonNegativeMatrixFactorization nmf =
                new NonNegativeMatrixFactorization(training, numDimensions.value, useIDFOption.value);
        if (clusterSize.value > 0) {
            nmf.initialize(clusterSize.value);
        }

        System.out.println("Finding " + numDimensions.value + " factors.");
        System.out.println("Histograms show relative factor sizes, the number measures factorization error (smaller is better).");

        double previousDivergence = Double.POSITIVE_INFINITY;
        for (int iteration = 1; iteration <= numIterationsOption.value; iteration++) {
            nmf.updateWeights();
            if (iteration % 100 == 0) {
                nmf.printFactorFeatures(15);
            }
            if (iteration % 10 == 0) {
                double divergence = nmf.getDivergence();
                System.out.println(getBars(nmf.featureSums) + "\t" + getBars(nmf.instanceSums) + "\t" + divergence);
                // Stop once a further ten iterations improved the divergence by less than 0.01%.
                if (divergence / previousDivergence > 0.9999d) {
                    break;
                }
                previousDivergence = divergence;
            }
        }

        // try-with-resources closes (and so flushes) each writer even if writing throws.
        if (outputWordsFile.value != null) {
            System.out.println("Writing to " + outputWordsFile.value);
            try (PrintWriter wordsOut = new PrintWriter(new File(outputWordsFile.value))) {
                nmf.writeFeatureFactors(wordsOut);
            }
        }
        if (outputDocsFile.value != null) {
            System.out.println("Writing to " + outputDocsFile.value);
            try (PrintWriter docsOut = new PrintWriter(new File(outputDocsFile.value))) {
                nmf.writeInstanceFactors(docsOut);
            }
        }
    }

    // Static initializer: sets the assertion flag, constructs the command-line options and
    // builds the BARS table. The statement order is left exactly as found.
    static {
        // True when assertions are NOT enabled for this class; methods test it before throwing AssertionError.
        $assertionsDisabled = !NonNegativeMatrixFactorization.class.desiredAssertionStatus();
        // NOTE(review): the help text mentions "-" for stdin, but main() passes the value straight to new File(...).
        inputFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "input", "FILENAME", true, null, "The filename from which to read the list of training instances.  Use - for stdin.  The instances must be FeatureVectors, not FeatureSequences", null);
        outputWordsFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "output-words", "FILENAME", true, "word-weights.txt", "The filename to write weights for each word.", null);
        outputDocsFile = new CommandOption.String(NonNegativeMatrixFactorization.class, "output-docs", "FILENAME", true, "doc-weights.txt", "The filename to write weights for each document.", null);
        numDimensions = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "num-dimensions", "INTEGER", true, 50, "The number of dimensions to fit.", null);
        clusterSize = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "init-cluster-size", "INTEGER", true, 0, "Select this number of random instances to initialize each dimension. 0 = off.", null);
        useIDFOption = new CommandOption.Boolean(NonNegativeMatrixFactorization.class, "use-idf", "TRUE/FALSE", true, true, "Whether to use IDF weighting.", null);
        numIterationsOption = new CommandOption.Integer(NonNegativeMatrixFactorization.class, "num-iters", "INTEGER", true, 1000, "The number of passes through the training data.", null);
        // Index 0 is a blank; indices 1 to 8 are block characters; getBar() picks one by rounding to 0..8.
        BARS = new String[]{" ", "▁", "▂", "▃", "▄", "▅", "▆", "▇", "█"};
    }
}
