/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.classification.bayes;

import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.algorithm.estimator.DensityEstimator;
import org.jdmp.core.algorithm.estimator.GaussianDensityEstimator;
import org.jdmp.core.algorithm.estimator.GeneralDensityEstimator;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.util.MathUtil;

/**
 * Naive Bayes classifier that keeps one density estimator per (feature, class) pair plus one
 * estimator per class for the class prior.
 *
 * <p>When {@code isDiscrete} reports discrete data, features are modelled with
 * {@link GeneralDensityEstimator}; otherwise {@link GaussianDensityEstimator} is used. Class
 * priors are always modelled with a {@link GeneralDensityEstimator} fed 0/1 membership values.
 */
public class NaiveBayesClassifier extends AbstractClassifier {
    private static final long serialVersionUID = -4962565315819543623L;

    /** Per-feature, per-class estimators indexed {@code [feature][class]}; null until trained. */
    private DensityEstimator[][] dists = null;

    /** Per-class estimators of the 0/1 class-membership value; null until trained. */
    private DensityEstimator[] classDists = null;

    /** Number of classes determined at training time; -1 while untrained. */
    private int classCount = -1;

    /** Creates an untrained classifier via the no-argument {@link AbstractClassifier} constructor. */
    public NaiveBayesClassifier() {
    }

    /**
     * Creates an untrained classifier.
     *
     * @param inputLabel passed to the {@link AbstractClassifier} constructor (presumably the
     *     label under which samples store their input matrix — confirm in AbstractClassifier)
     */
    public NaiveBayesClassifier(String inputLabel) {
        super(inputLabel);
    }

    /**
     * Computes class probabilities for a single input.
     *
     * <p>Each class score starts at the log of that class's estimated probability of membership
     * value {@code 1.0}. For every feature it then adds the log of the class's density at the
     * feature value, divided by the sum of that density over all classes. The log-scores are
     * converted with {@link MathUtil#logToProbs(double[])}.
     *
     * @param input feature values; linked as a column vector and read at {@code (0, j)}
     * @return the values from {@code MathUtil.logToProbs}, linked into a matrix and transposed
     * @throws IllegalStateException if the classifier has not been trained, or was reset
     */
    @Override
    public Matrix predictOne(Matrix input) {
        if (this.dists == null || this.classDists == null || this.classCount < 0) {
            throw new IllegalStateException("classifier has not been trained");
        }
        Matrix vector = input.toColumnVector(Calculation.Ret.LINK);
        double[] probs = new double[this.classCount];
        double[] logs = new double[this.classCount];
        for (int c = 0; c < this.classCount; ++c) {
            logs[c] += Math.log(this.classDists[c].getProbability(1.0));
        }
        long featureCount = vector.getColumnCount();
        for (int j = 0; j < featureCount; ++j) {
            // The feature value is the same for every class, so read it once.
            double value = vector.getAsDouble(0L, j);
            double probSum = 0.0;
            for (int c = 0; c < this.classCount; ++c) {
                probs[c] = this.dists[j][c].getProbability(value);
                probSum += probs[c];
            }
            // If no class assigns positive density to this value, the feature gives no relative
            // evidence between classes, and dividing by probSum would produce 0/0 = NaN in every
            // class score. Skip it. The negated comparison also skips a NaN sum.
            if (!(probSum > 0.0)) {
                continue;
            }
            for (int c = 0; c < this.classCount; ++c) {
                logs[c] += Math.log(probs[c] / probSum);
            }
        }
        double[] finalProbs = MathUtil.logToProbs(logs);
        return Matrix.Factory.linkToArray(finalProbs).transpose();
    }

    /** Discards all trained state so the classifier must be trained again before predicting. */
    @Override
    public void reset() {
        this.dists = null;
        this.classDists = null;
        this.classCount = -1;
    }

    /**
     * Trains the classifier on every sample of the data set, replacing any previous model.
     *
     * <p>For each class {@code j}, the class estimator receives the weighted value 0.0 or 1.0
     * depending on whether the sample's target value for {@code j} is zero. The feature
     * estimators for class {@code j} are updated only with samples whose target value for
     * {@code j} is non-zero. Progress messages are printed to {@code System.out}.
     *
     * @param dataSet training samples; the feature count is taken from the first sample's input
     * @throws IllegalArgumentException if {@code dataSet} is empty
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        if (dataSet.isEmpty()) {
            throw new IllegalArgumentException("cannot train on an empty data set");
        }
        System.out.println("training started");
        Matrix firstInput = ((Sample) dataSet.get(0)).getAsMatrix(this.getInputLabel());
        int featureCount = (int) firstInput.getValueCount();
        boolean discrete = this.isDiscrete(dataSet);
        this.classCount = this.getClassCount(dataSet);
        this.dists = new DensityEstimator[featureCount][this.classCount];
        this.classDists = new DensityEstimator[this.classCount];
        for (int j = 0; j < this.classCount; ++j) {
            this.classDists[j] = new GeneralDensityEstimator();
            for (int i = 0; i < featureCount; ++i) {
                this.dists[i][j] = discrete ? new GeneralDensityEstimator() : new GaussianDensityEstimator();
            }
        }
        System.out.println("density estimators created");
        int count = 0;
        for (Sample s : dataSet) {
            Matrix sampleInput = s.getAsMatrix(this.getInputLabel()).toColumnVector(Calculation.Ret.LINK);
            Matrix sampleTarget = s.getAsMatrix(this.getTargetLabel()).toColumnVector(Calculation.Ret.LINK);
            double weight = s.getWeight();
            long inputCount = sampleInput.getColumnCount();
            for (int j = 0; j < this.classCount; ++j) {
                if (sampleTarget.getAsDouble(0L, j) == 0.0) {
                    // Sample is not in class j: record a 0 membership and leave j's features alone.
                    this.classDists[j].addValue(0.0, weight);
                    continue;
                }
                this.classDists[j].addValue(1.0, weight);
                for (int i = 0; i < inputCount; ++i) {
                    this.dists[i][j].addValue(sampleInput.getAsDouble(0L, i), weight);
                }
            }
            if (++count % 10000 == 0) {
                System.out.println(count);
            }
        }
        System.out.println("training finished");
    }

    /**
     * Creates an untrained classifier carrying this instance's input and target labels.
     *
     * @return a new {@code NaiveBayesClassifier} with copied labels and no trained state
     */
    @Override
    public Classifier emptyCopy() {
        NaiveBayesClassifier copy = new NaiveBayesClassifier();
        copy.setInputLabel(this.getInputLabel());
        copy.setTargetLabel(this.getTargetLabel());
        return copy;
    }
}

