/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.bayesian;

import java.util.Arrays;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.parameters.Parameterized;

/**
 * Multinomial naive Bayes classifier for count-like numeric features (for example term
 * frequencies), with additive smoothing. Categorical features, if any, are modeled with a
 * per-class smoothed categorical distribution.
 * <p>
 * The model is updateable: {@link #update(DataPoint, int)} accumulates weighted counts.
 * {@link #finalizeModel()} replaces those raw counts in place with their smoothed
 * log-probabilities. This makes classification cheaper, but after it no further updates
 * are possible. Negative numeric values are not valid counts. They are ignored both when
 * training and when classifying.
 */
public class MultinomialNaiveBayes
extends BaseUpdateableClassifier
implements Parameterized {
    private static final long serialVersionUID = -469977945722725478L;
    /**
     * Indexed [class][categorical feature][category value]. Before finalization this holds
     * the weighted count of each category value. After finalization it holds that value's
     * smoothed log-probability.
     */
    private double[][][] apriori;
    /**
     * Indexed [class][numeric feature]. Before finalization this holds the weighted sum of
     * each (non-negative) feature value. After finalization it holds that feature's
     * smoothed log-probability.
     */
    private double[][] wordCounts;
    /** Per class: the weighted total of all non-negative numeric values added so far. */
    private double[] totalWords;
    /** Total weight of all data points seen so far. */
    private double priorSum = 0.0;
    /** Per class: the total weight seen, replaced by the smoothed log-prior on finalization. */
    private double[] priors;
    /** Additive smoothing pseudo-count; always finite and strictly positive. */
    private double smoothing;
    /** Whether {@code train(...)} finalizes the model automatically. */
    private boolean finalizeAfterTraining = true;
    /** True once the count arrays have been converted to log-probabilities. */
    private boolean finalized;

    /**
     * Creates a classifier with a smoothing constant of 1.0.
     */
    public MultinomialNaiveBayes() {
        this(1.0);
    }

    /**
     * Creates a classifier with the given smoothing constant.
     *
     * @param smoothing the additive smoothing pseudo-count, in (0, Inf)
     * @throws IllegalArgumentException if {@code smoothing} is not finite and positive
     */
    public MultinomialNaiveBayes(double smoothing) {
        this.setSmoothing(smoothing);
        this.setEpochs(1);
    }

    /**
     * Deep-copy constructor. Every model array is copied, so the copy and the original can
     * be updated independently.
     *
     * @param other the classifier to copy
     */
    protected MultinomialNaiveBayes(MultinomialNaiveBayes other) {
        this(other.smoothing);
        if (other.apriori != null) {
            int nClasses = other.apriori.length;
            this.apriori = new double[nClasses][][];
            this.wordCounts = new double[nClasses][];
            for (int c = 0; c < nClasses; ++c) {
                double[][] srcCats = other.apriori[c];
                this.apriori[c] = new double[srcCats.length][];
                for (int j = 0; j < srcCats.length; ++j) {
                    this.apriori[c][j] = Arrays.copyOf(srcCats[j], srcCats[j].length);
                }
                this.wordCounts[c] = Arrays.copyOf(other.wordCounts[c], other.wordCounts[c].length);
            }
            this.totalWords = Arrays.copyOf(other.totalWords, other.totalWords.length);
            this.priors = Arrays.copyOf(other.priors, other.priors.length);
            this.priorSum = other.priorSum;
        }
        this.finalizeAfterTraining = other.finalizeAfterTraining;
        this.finalized = other.finalized;
    }

    /**
     * Sets the additive smoothing pseudo-count. Changing it after the model has been
     * finalized has no effect on the already-computed log-probabilities.
     *
     * @param smoothing the smoothing constant, in (0, Inf)
     * @throws IllegalArgumentException if {@code smoothing} is NaN, infinite, or not positive
     */
    public void setSmoothing(double smoothing) {
        if (Double.isNaN(smoothing) || Double.isInfinite(smoothing) || smoothing <= 0.0) {
            throw new IllegalArgumentException("Smoothing constant must be in range (0,Inf), not " + smoothing);
        }
        this.smoothing = smoothing;
    }

    /**
     * @return the additive smoothing pseudo-count
     */
    public double getSmoothing() {
        return this.smoothing;
    }

    /**
     * Controls whether {@code train(...)} calls {@link #finalizeModel()} when it finishes.
     *
     * @param finalizeAfterTraining {@code true} to finalize automatically after training
     */
    public void setFinalizeAfterTraining(boolean finalizeAfterTraining) {
        this.finalizeAfterTraining = finalizeAfterTraining;
    }

    /**
     * @return whether {@code train(...)} finalizes the model automatically
     */
    public boolean isFinalizeAfterTraining() {
        return this.finalizeAfterTraining;
    }

    @Override
    public MultinomialNaiveBayes clone() {
        return new MultinomialNaiveBayes(this);
    }

    /**
     * Trains on the data set via the base class. If {@link #isFinalizeAfterTraining()} is
     * set, the model is then finalized.
     */
    @Override
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        super.train(dataSet, parallel);
        if (this.finalizeAfterTraining) {
            this.finalizeModel();
        }
    }

    /**
     * Trains on the data set via the base class. If {@link #isFinalizeAfterTraining()} is
     * set, the model is then finalized.
     */
    @Override
    public void train(ClassificationDataSet dataSet) {
        super.train(dataSet);
        if (this.finalizeAfterTraining) {
            this.finalizeModel();
        }
    }

    /**
     * Converts the accumulated counts in place into smoothed log-probabilities. Afterwards
     * classification uses the precomputed values and {@link #update(DataPoint, int)} is
     * rejected. Calling this again on a finalized model does nothing.
     *
     * @throws UntrainedModelException if the model has not been set up yet
     */
    public void finalizeModel() {
        if (this.finalized) {
            return;
        }
        if (this.apriori == null) {
            throw new UntrainedModelException("Model has not been initialized");
        }
        double priorSumSmooth = smoothedPriorSum();
        for (int c = 0; c < this.priors.length; ++c) {
            // logPrior reads the raw priors[c] before it is overwritten here.
            this.priors[c] = logPrior(c, priorSumSmooth);
            // The normalizer depends only on totalWords[c] and the array length, so it is
            // safe to compute before the counts are overwritten in place.
            double logTotalCounts = logWordNormalizer(c);
            double[] counts = this.wordCounts[c];
            for (int i = 0; i < counts.length; ++i) {
                counts[i] = Math.log(counts[i] + this.smoothing) - logTotalCounts;
            }
            for (double[] catCounts : this.apriori[c]) {
                double sum = smoothedSum(catCounts);
                for (int z = 0; z < catCounts.length; ++z) {
                    catCounts[z] = Math.log((catCounts[z] + this.smoothing) / sum);
                }
            }
        }
        this.finalized = true;
    }

    /**
     * Allocates zeroed count arrays for the given attribute layout and clears any
     * finalized state.
     */
    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        int nCat = predicting.getNumOfCategories();
        this.apriori = new double[nCat][categoricalAttributes.length][];
        this.wordCounts = new double[nCat][numericAttributes];
        this.totalWords = new double[nCat];
        this.priors = new double[nCat];
        this.priorSum = 0.0;
        for (int i = 0; i < nCat; ++i) {
            for (int j = 0; j < categoricalAttributes.length; ++j) {
                this.apriori[i][j] = new double[categoricalAttributes[j].getNumOfCategories()];
            }
        }
        this.finalized = false;
    }

    /**
     * Adds one weighted observation to the class counts. Negative numeric values are not
     * valid counts and are skipped.
     *
     * @throws FailedToFitException if the model has already been finalized
     * @throws UntrainedModelException if the model has not been set up yet
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        if (this.finalized) {
            throw new FailedToFitException("Model has already been finalized, and can no longer be updated");
        }
        if (this.apriori == null) {
            throw new UntrainedModelException("Model has not been initialized");
        }
        double weight = dataPoint.getWeight();
        int[] catValues = dataPoint.getCategoricalValues();
        double[][] classCatCounts = this.apriori[targetClass];
        for (int j = 0; j < classCatCounts.length; ++j) {
            classCatCounts[j][catValues[j]] += weight;
        }
        double[] classWordCounts = this.wordCounts[targetClass];
        double localCountsAdded = 0.0;
        for (IndexValue iv : dataPoint.getNumericalValues()) {
            double v = iv.getValue();
            if (v < 0.0) {
                continue;
            }
            double added = v * weight;
            classWordCounts[iv.getIndex()] += added;
            localCountsAdded += added;
        }
        this.totalWords[targetClass] += localCountsAdded;
        this.priors[targetClass] += weight;
        this.priorSum += weight;
    }

    /**
     * Computes the posterior class probabilities for a data point. This works on both
     * finalized and non-finalized models. Negative numeric values are ignored, matching
     * {@link #update(DataPoint, int)}.
     *
     * @throws UntrainedModelException if the model has not been set up yet
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        if (this.apriori == null) {
            throw new UntrainedModelException("Model has not been intialized");
        }
        int nClasses = this.apriori.length;
        CategoricalResults results = new CategoricalResults(nClasses);
        double[] logProbs = new double[nClasses];
        double maxLogProb = Double.NEGATIVE_INFINITY;
        Vec numVals = data.getNumericalValues();
        // Only the non-finalized path needs the smoothed normalizer for the priors.
        double priorSumSmooth = this.finalized ? 0.0 : smoothedPriorSum();
        for (int c = 0; c < nClasses; ++c) {
            double logProb = this.finalized
                    ? finalizedLogJoint(c, data, numVals)
                    : rawLogJoint(c, data, numVals, priorSumSmooth);
            logProbs[c] = logProb;
            maxLogProb = Math.max(maxLogProb, logProb);
        }
        double denom = MathTricks.logSumExp(logProbs, maxLogProb);
        for (int i = 0; i < results.size(); ++i) {
            results.setProb(i, Math.exp(logProbs[i] - denom));
        }
        results.normalize();
        return results;
    }

    @Override
    public boolean supportsWeightedData() {
        return true;
    }

    /** Unnormalized log joint probability of class c, using finalized log-probabilities. */
    private double finalizedLogJoint(int c, DataPoint data, Vec numVals) {
        double logProb = this.priors[c];
        double[] logWordProbs = this.wordCounts[c];
        for (IndexValue iv : numVals) {
            double v = iv.getValue();
            if (v < 0.0) {
                continue;
            }
            logProb += v * logWordProbs[iv.getIndex()];
        }
        double[][] logCatProbs = this.apriori[c];
        for (int j = 0; j < logCatProbs.length; ++j) {
            logProb += logCatProbs[j][data.getCategoricalValue(j)];
        }
        return logProb;
    }

    /** Unnormalized log joint probability of class c, computed on the fly from raw counts. */
    private double rawLogJoint(int c, DataPoint data, Vec numVals, double priorSumSmooth) {
        double logProb = logPrior(c, priorSumSmooth);
        double[] counts = this.wordCounts[c];
        double logTotalCounts = logWordNormalizer(c);
        for (IndexValue iv : numVals) {
            double v = iv.getValue();
            if (v < 0.0) {
                continue;
            }
            logProb += v * (Math.log(counts[iv.getIndex()] + this.smoothing) - logTotalCounts);
        }
        double[][] catCounts = this.apriori[c];
        for (int j = 0; j < catCounts.length; ++j) {
            double[] valueCounts = catCounts[j];
            double p = valueCounts[data.getCategoricalValue(j)] + this.smoothing;
            logProb += Math.log(p / smoothedSum(valueCounts));
        }
        return logProb;
    }

    /** Total prior weight plus one smoothing pseudo-count per class. */
    private double smoothedPriorSum() {
        return this.priorSum + (double) this.priors.length * this.smoothing;
    }

    /** Smoothed log-prior of class c computed from the raw (non-finalized) class weight. */
    private double logPrior(int c, double priorSumSmooth) {
        return Math.log((this.priors[c] + this.smoothing) / priorSumSmooth);
    }

    /** Log of the smoothed normalizer for class c's numeric-feature distribution. */
    private double logWordNormalizer(int c) {
        return Math.log(this.totalWords[c] + this.smoothing * (double) this.wordCounts[c].length);
    }

    /** Sum of the given counts, each incremented by the smoothing constant. */
    private double smoothedSum(double[] counts) {
        double sum = 0.0;
        for (double count : counts) {
            sum += count + this.smoothing;
        }
        return sum;
    }
}

