package org.jdmp.core.algorithm.classification.bayes;

import java.util.Iterator;
import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.algorithm.estimator.DensityEstimator;
import org.jdmp.core.algorithm.estimator.GaussianDensityEstimator;
import org.jdmp.core.algorithm.estimator.GeneralDensityEstimator;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.dataset.MatrixDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.util.MathUtil;

/* loaded from: input_file:org/jdmp/core/algorithm/classification/bayes/NaiveBayesClassifier.class */
/**
 * Naive Bayes classifier built from one density estimator per (feature, class) pair plus one
 * estimator per class for the class prior.
 *
 * <p>Feature likelihoods are modeled with {@link GeneralDensityEstimator} when the training data
 * is reported as discrete by {@code isDiscrete}, and with {@link GaussianDensityEstimator}
 * otherwise. The class prior is estimated by feeding each class estimator the value 1.0 for
 * samples whose target entry for that class is non-zero, and 0.0 otherwise.
 */
public class NaiveBayesClassifier extends AbstractClassifier {
    private static final long serialVersionUID = -4962565315819543623L;

    /**
     * Likelihood estimators indexed {@code [feature][class]}; null until trained. Initialized
     * explicitly so the value is (re)set after the superclass constructor has run.
     */
    private DensityEstimator[][] dists = null;

    /** Per-class estimators over the 0/1 class-membership indicator; null until trained. */
    private DensityEstimator[] classDists = null;

    /** Number of classes determined during training, or -1 when untrained. */
    private int classCount = -1;

    /** Creates an untrained classifier. */
    public NaiveBayesClassifier() {
    }

    /**
     * Creates an untrained classifier, passing {@code str} to the superclass constructor.
     *
     * @param str value forwarded to {@code AbstractClassifier(String)}
     */
    public NaiveBayesClassifier(String str) {
        super(str);
    }

    /**
     * Computes class-membership probabilities for a single input.
     *
     * <p>Each class score starts at the log of the class prior,
     * {@code classDists[c].getProbability(1.0)}. For every feature it then adds the log of that
     * class's likelihood divided by the sum of likelihoods over all classes. The division subtracts
     * the same {@code log(total)} from every class, so it does not change the ranking. A feature
     * whose likelihood is zero for every class is skipped, because normalizing it would produce
     * {@code 0/0 = NaN} for all classes. The log scores are converted by
     * {@link MathUtil#logToProbs} (presumably exponentiated and normalized; confirm in UJMP).
     *
     * @param matrix input features; read through a linked {@code toColumnVector} view
     * @return the class probabilities as a matrix
     * @throws IllegalStateException if the classifier has not been trained, or was reset
     */
    @Override
    public Matrix predictOne(Matrix matrix) {
        if (dists == null || classDists == null || classCount <= 0) {
            throw new IllegalStateException("classifier has not been trained");
        }
        Matrix input = matrix.toColumnVector(Calculation.Ret.LINK);
        double[] likelihoods = new double[classCount];
        double[] logScores = new double[classCount];
        for (int c = 0; c < classCount; c++) {
            logScores[c] = Math.log(classDists[c].getProbability(1.0d));
        }
        int featureCount = (int) input.getColumnCount();
        for (int f = 0; f < featureCount; f++) {
            double value = input.getAsDouble(new long[] {0, f});
            double total = 0.0d;
            for (int c = 0; c < classCount; c++) {
                likelihoods[c] = dists[f][c].getProbability(value);
                total += likelihoods[c];
            }
            if (total == 0.0d) {
                // No class assigns any density to this value: the feature cannot discriminate,
                // and normalizing would turn every class score into NaN.
                continue;
            }
            for (int c = 0; c < classCount; c++) {
                logScores[c] += Math.log(likelihoods[c] / total);
            }
        }
        return Matrix.Factory.linkToArray(MathUtil.logToProbs(logScores)).transpose();
    }

    /** Discards all trained estimators, returning the classifier to its untrained state. */
    @Override
    public void reset() {
        dists = null;
        classDists = null;
        classCount = -1;
    }

    /**
     * Trains the classifier on every sample of the data set, replacing any previous model.
     *
     * <p>The number of features is taken from the value count of the first sample's input matrix.
     * For each sample and each class {@code c}, the target entry at {@code (0, c)} decides
     * membership. A zero entry adds 0.0 to the class estimator. Any other entry adds 1.0 to the
     * class estimator and adds every feature value to that class's feature estimators. All values
     * are weighted by the sample's weight. Progress is printed to standard output.
     *
     * @param listDataSet training samples; must not be empty
     * @throws IllegalArgumentException if the data set contains no samples
     */
    @Override
    public void trainAll(ListDataSet listDataSet) {
        System.out.println("training started");
        if (!listDataSet.iterator().hasNext()) {
            throw new IllegalArgumentException("cannot train on an empty data set");
        }
        int featureCount = (int) ((Sample) listDataSet.get(0)).getAsMatrix(getInputLabel()).getValueCount();
        boolean discrete = isDiscrete(listDataSet);
        classCount = getClassCount(listDataSet);
        dists = new DensityEstimator[featureCount][classCount];
        classDists = new DensityEstimator[classCount];
        for (int c = 0; c < classCount; c++) {
            classDists[c] = new GeneralDensityEstimator();
            for (int f = 0; f < featureCount; f++) {
                if (discrete) {
                    dists[f][c] = new GeneralDensityEstimator();
                } else {
                    dists[f][c] = new GaussianDensityEstimator();
                }
            }
        }
        System.out.println("density estimators created");
        int processed = 0;
        Iterator<?> it = listDataSet.iterator();
        while (it.hasNext()) {
            Sample sample = (Sample) it.next();
            Matrix input = sample.getAsMatrix(getInputLabel()).toColumnVector(Calculation.Ret.LINK);
            Matrix target = sample.getAsMatrix(getTargetLabel()).toColumnVector(Calculation.Ret.LINK);
            double weight = sample.getWeight();
            int inputCount = (int) input.getColumnCount();
            for (int c = 0; c < classCount; c++) {
                if (target.getAsDouble(new long[] {0, c}) == 0.0d) {
                    classDists[c].addValue(0.0d, weight);
                } else {
                    classDists[c].addValue(1.0d, weight);
                    for (int f = 0; f < inputCount; f++) {
                        dists[f][c].addValue(input.getAsDouble(new long[] {0, f}), weight);
                    }
                }
            }
            processed++;
            if (processed % MatrixDataSet.MAXSAMPLES == 0) {
                System.out.println(processed);
            }
        }
        System.out.println("training finished");
    }

    /**
     * Creates a new untrained classifier with the same input and target labels.
     *
     * @return a fresh {@code NaiveBayesClassifier} sharing this instance's labels
     */
    @Override
    public Classifier emptyCopy() {
        NaiveBayesClassifier copy = new NaiveBayesClassifier();
        copy.setInputLabel(getInputLabel());
        copy.setTargetLabel(getTargetLabel());
        return copy;
    }
}
