package smile.classification;

import java.io.Serializable;
import java.util.Iterator;
import smile.math.SparseArray;
import smile.stat.distribution.Distribution;
import smile.swing.FontChooser;

/* loaded from: input_file:smile/classification/NaiveBayes.class */
/**
 * Naive Bayes classifier.
 *
 * <p>Three model variants are supported (see {@link Model}):
 * <ul>
 *   <li>{@link Model#GENERAL}: every class/feature pair is described by a caller-supplied
 *       {@link Distribution}; the log score adds {@code Distribution.logp(x_j)} per feature.
 *       This variant cannot be trained online.</li>
 *   <li>{@link Model#MULTINOMIAL}: feature values are accumulated as per-class counts
 *       (truncated to {@code int}) and turned into add-k smoothed conditional probabilities
 *       {@code (ntc + sigma) / (nt + sigma * p)}.</li>
 *   <li>{@link Model#BERNOULLI}: a feature is "present" when its value is {@code > 0}; the
 *       conditional probabilities are {@code (ntc + sigma) / (nc + 2 * sigma)}.</li>
 * </ul>
 *
 * <p>Priori probabilities are either supplied by the caller (and then never re-estimated) or
 * estimated from observed class frequencies each time training data arrives. Instances perform
 * no internal synchronization.
 */
public class NaiveBayes implements OnlineClassifier<double[]>, SoftClassifier<double[]>, Serializable {
    private static final long serialVersionUID = 1L;

    /** Pseudo-count added to each class count when priori probabilities are estimated. */
    private static final double EPSILON = 1.0E-20d;

    /** Maximum allowed deviation of the sum of priori probabilities from one. */
    private static final double PRIORI_SUM_TOLERANCE = 1.0E-10d;

    // NOTE: field names and types are part of the default serialized form; keep them unchanged.

    /** The model variant; decides which of the fields below are in use. */
    private Model model;
    /** Number of classes. */
    private int k;
    /** Dimension of the input vectors. */
    private int p;
    /** Priori probability of each class (length {@code k}). */
    private double[] priori;
    /** GENERAL model only: per-class, per-feature distributions ({@code k x p}). */
    private Distribution[][] prob;
    /** Add-k smoothing parameter used by the MULTINOMIAL and BERNOULLI models. */
    private double sigma;
    /** True when {@link #priori} was supplied by the caller and must not be re-estimated. */
    private boolean predefinedPriori;
    /** Number of training instances seen so far. */
    private int n;
    /** Number of training instances seen per class. */
    private int[] nc;
    /** MULTINOMIAL only: total (int-truncated) feature count per class. */
    private int[] nt;
    /**
     * MULTINOMIAL: per-class sum of each feature's values (int-truncated);
     * BERNOULLI: per-class number of instances in which the feature is {@code > 0}.
     */
    private int[][] ntc;
    /** Smoothed conditional probability of each feature given each class ({@code k x p}). */
    private double[][] condprob;

    /** The variants of the naive Bayes model. */
    public enum Model {
        /** Caller-supplied {@link Distribution} per class and feature; no online learning. */
        GENERAL,
        /** Feature values are treated as event counts. */
        MULTINOMIAL,
        /** Feature values are treated as binary presence indicators ({@code x > 0}). */
        BERNOULLI
    }

    /** Builder-style trainer that creates and batch-trains a {@link NaiveBayes} model. */
    public static class Trainer extends ClassifierTrainer<double[]> {
        private Model model;
        private int k;
        private int p;
        private double[] priori;
        private double sigma = 1.0d;

        /**
         * Creates a trainer whose model estimates priori probabilities from the data.
         *
         * @param model the model variant
         * @param k the number of classes, at least 2
         * @param p the input dimension, positive
         * @throws IllegalArgumentException if {@code k < 2} or {@code p <= 0}
         */
        public Trainer(Model model, int k, int p) {
            if (k < 2) {
                throw new IllegalArgumentException("Invalid number of classes: " + k);
            }
            if (p <= 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            this.model = model;
            this.k = k;
            this.p = p;
        }

        /**
         * Creates a trainer whose model uses fixed priori probabilities.
         *
         * @param model the model variant
         * @param priori the priori probability of each class; each in (0, 1), summing to one
         * @param p the input dimension, positive
         * @throws IllegalArgumentException if any argument is invalid
         */
        public Trainer(Model model, double[] priori, int p) {
            if (p <= 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            if (priori.length < 2) {
                throw new IllegalArgumentException("Invalid number of classes: " + priori.length);
            }
            checkPriori(priori);
            this.model = model;
            this.priori = priori;
            this.k = priori.length;
            this.p = p;
        }

        /**
         * Sets fixed priori probabilities; {@code null} means "estimate from the data".
         *
         * @throws IllegalArgumentException if a non-null array is not a valid distribution
         */
        public Trainer setPriori(double[] priori) {
            if (priori != null) {
                // Fail here rather than later inside train().
                if (priori.length < 2) {
                    throw new IllegalArgumentException("Invalid number of classes: " + priori.length);
                }
                checkPriori(priori);
            }
            this.priori = priori;
            return this;
        }

        /**
         * Sets the add-k smoothing parameter.
         *
         * @throws IllegalArgumentException if {@code sigma} is negative or NaN
         */
        public Trainer setSmooth(double sigma) {
            checkSmooth(sigma);
            this.sigma = sigma;
            return this;
        }

        /** Creates a model from the trainer's settings and batch-trains it on {@code (x, y)}. */
        @Override
        public NaiveBayes train(double[][] x, int[] y) {
            NaiveBayes naiveBayes = (priori == null)
                    ? new NaiveBayes(model, k, p, sigma)
                    : new NaiveBayes(model, priori, p, sigma);
            naiveBayes.learn(x, y);
            return naiveBayes;
        }
    }

    /**
     * Creates a GENERAL-mode classifier from fixed priori probabilities and per-class,
     * per-feature distributions.
     *
     * @param priori the priori probability of each class; each in (0, 1), summing to one
     * @param distributions {@code distributions[c][j]} describes feature {@code j} in class
     *     {@code c}; all rows must have the same length
     * @throws IllegalArgumentException if the arguments are inconsistent or invalid
     */
    public NaiveBayes(double[] priori, Distribution[][] distributions) {
        if (priori.length != distributions.length) {
            throw new IllegalArgumentException("The number of priori probabilities and that of the classes are not same.");
        }
        // An empty priori array fails here (sum 0), so distributions[0] below is safe.
        checkPriori(priori);
        int dim = distributions[0].length;
        for (int c = 1; c < distributions.length; c++) {
            if (distributions[c].length != dim) {
                throw new IllegalArgumentException(String.format(
                        "Invalid number of distributions for class %d: %d, expected: %d",
                        c, distributions[c].length, dim));
            }
        }
        this.model = Model.GENERAL;
        this.k = priori.length;
        this.p = dim;
        this.priori = priori;
        this.prob = distributions;
        this.predefinedPriori = true;
    }

    /** Same as {@link #NaiveBayes(Model, int, int, double)} with smoothing {@code 1.0}. */
    public NaiveBayes(Model model, int k, int p) {
        this(model, k, p, 1.0d);
    }

    /**
     * Creates an untrained classifier whose priori probabilities are estimated from the data.
     *
     * @param model the model variant
     * @param k the number of classes, at least 2
     * @param p the input dimension, positive
     * @param sigma the add-k smoothing parameter, non-negative
     * @throws IllegalArgumentException if any argument is invalid
     */
    public NaiveBayes(Model model, int k, int p, double sigma) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + k);
        }
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        checkSmooth(sigma);
        this.model = model;
        this.k = k;
        this.p = p;
        this.sigma = sigma;
        this.predefinedPriori = false;
        this.priori = new double[k];
        this.n = 0;
        this.nc = new int[k];
        this.nt = new int[k];
        this.ntc = new int[k][p];
        this.condprob = new double[k][p];
    }

    /** Same as {@link #NaiveBayes(Model, double[], int, double)} with smoothing {@code 1.0}. */
    public NaiveBayes(Model model, double[] priori, int p) {
        this(model, priori, p, 1.0d);
    }

    /**
     * Creates an untrained classifier with fixed priori probabilities.
     *
     * @param model the model variant
     * @param priori the priori probability of each class; each in (0, 1), summing to one
     * @param p the input dimension, positive
     * @param sigma the add-k smoothing parameter, non-negative
     * @throws IllegalArgumentException if any argument is invalid
     */
    public NaiveBayes(Model model, double[] priori, int p, double sigma) {
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        checkSmooth(sigma);
        if (priori.length < 2) {
            throw new IllegalArgumentException("Invalid number of classes: " + priori.length);
        }
        checkPriori(priori);
        this.model = model;
        this.k = priori.length;
        this.p = p;
        this.sigma = sigma;
        this.priori = priori;
        this.predefinedPriori = true;
        this.n = 0;
        this.nc = new int[k];
        this.nt = new int[k];
        this.ntc = new int[k][p];
        this.condprob = new double[k][p];
    }

    /**
     * Returns the priori probabilities. This is the live internal array, not a copy; for
     * models that estimate priori probabilities it is overwritten by each call to learn.
     */
    public double[] getPriori() {
        return priori;
    }

    /**
     * Updates the model with one dense training instance.
     *
     * @throws UnsupportedOperationException for a GENERAL-mode model
     * @throws IllegalArgumentException if {@code x.length != p} or {@code y} is not in {@code [0, k)}
     */
    @Override
    public void learn(double[] x, int y) {
        requireOnlineModel();
        if (x.length != p) {
            throw new IllegalArgumentException("Invalid input vector size: " + x.length);
        }
        checkLabel(y);
        accumulate(x, y);
        update();
    }

    /**
     * Updates the model with one sparse training instance; absent entries are zeros.
     *
     * @throws UnsupportedOperationException for a GENERAL-mode model
     * @throws IllegalArgumentException if {@code y} or any entry index is out of range
     */
    public void learn(SparseArray x, int y) {
        requireOnlineModel();
        checkLabel(y);
        // Validate every index before touching any counter so a bad entry leaves the model intact.
        Iterator<SparseArray.Entry> check = x.iterator();
        while (check.hasNext()) {
            checkIndex(check.next().i);
        }

        int[] counts = ntc[y];
        Iterator<SparseArray.Entry> it = x.iterator();
        while (it.hasNext()) {
            SparseArray.Entry e = it.next();
            if (model == Model.MULTINOMIAL) {
                // Compound assignment truncates the sum to int, like the dense path.
                counts[e.i] += e.x;
                nt[y] += e.x;
            } else if (e.x > 0.0d) {
                counts[e.i]++;
            }
        }
        n++;
        nc[y]++;
        update();
    }

    /**
     * Updates the model with a batch of dense training instances. All instances are validated
     * before any of them is applied, so an invalid batch leaves the model unchanged.
     *
     * @throws UnsupportedOperationException for a GENERAL-mode model
     * @throws IllegalArgumentException if the arrays' lengths differ, a vector has the wrong
     *     size, or a label is out of range
     */
    public void learn(double[][] x, int[] y) {
        requireOnlineModel();
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        for (int i = 0; i < x.length; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException("Invalid input vector size: " + x[i].length);
            }
            checkLabel(y[i]);
        }
        for (int i = 0; i < x.length; i++) {
            accumulate(x[i], y[i]);
        }
        update();
    }

    /** Adds one already-validated dense instance to the counters (without recomputing probabilities). */
    private void accumulate(double[] x, int y) {
        int[] counts = ntc[y];
        if (model == Model.MULTINOMIAL) {
            for (int j = 0; j < p; j++) {
                // Compound assignment truncates the running sums to int.
                counts[j] += x[j];
                nt[y] += x[j];
            }
        } else {
            for (int j = 0; j < p; j++) {
                if (x[j] > 0.0d) {
                    counts[j]++;
                }
            }
        }
        n++;
        nc[y]++;
    }

    /** Recomputes priori (unless predefined) and smoothed conditional probabilities from the counters. */
    private void update() {
        if (!predefinedPriori) {
            for (int c = 0; c < k; c++) {
                // EPSILON keeps unseen classes from getting an exact zero priori.
                priori[c] = (nc[c] + EPSILON) / (n + k * EPSILON);
            }
        }
        if (model == Model.MULTINOMIAL) {
            for (int c = 0; c < k; c++) {
                for (int j = 0; j < p; j++) {
                    condprob[c][j] = (ntc[c][j] + sigma) / (nt[c] + sigma * p);
                }
            }
        } else {
            for (int c = 0; c < k; c++) {
                for (int j = 0; j < p; j++) {
                    // Two outcomes (present / absent), hence 2 * sigma in the denominator.
                    condprob[c][j] = (ntc[c][j] + sigma) / (nc[c] + sigma * 2.0d);
                }
            }
        }
    }

    /** Predicts the class of a dense vector; see {@link #predict(double[], double[])}. */
    @Override
    public int predict(double[] x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class of a dense vector.
     *
     * <p>For MULTINOMIAL and BERNOULLI models, an input with no feature {@code > 0} carries no
     * evidence: -1 is returned and {@code posteriori} (if given) is left holding the raw
     * per-class log scores instead of probabilities.
     *
     * @param x the input vector of length {@code p}
     * @param posteriori optional output array of length {@code k} receiving posteriori probabilities
     * @return the predicted class, or -1 as described above
     * @throws IllegalArgumentException if an array has the wrong size
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d", x.length));
        }
        checkPosterioriSize(posteriori);

        boolean hasEvidence = model == Model.GENERAL;
        if (!hasEvidence) {
            for (int j = 0; j < p && !hasEvidence; j++) {
                hasEvidence = x[j] > 0.0d;
            }
        }

        double[] scores = (posteriori != null) ? posteriori : new double[k];
        for (int c = 0; c < k; c++) {
            scores[c] = logScore(c, x);
        }
        return selectClass(scores, hasEvidence, posteriori != null);
    }

    /** Returns log(priori[c]) plus the log likelihood of the dense vector {@code x} under class {@code c}. */
    private double logScore(int c, double[] x) {
        double logp = Math.log(priori[c]);
        switch (model) {
            case GENERAL:
                for (int j = 0; j < p; j++) {
                    logp += prob[c][j].logp(x[j]);
                }
                break;
            case MULTINOMIAL:
                for (int j = 0; j < p; j++) {
                    if (x[j] > 0.0d) {
                        logp += x[j] * Math.log(condprob[c][j]);
                    }
                }
                break;
            case BERNOULLI:
                for (int j = 0; j < p; j++) {
                    logp += (x[j] > 0.0d) ? Math.log(condprob[c][j]) : Math.log(1.0d - condprob[c][j]);
                }
                break;
            default:
                throw new IllegalStateException("Unknown model: " + model);
        }
        return logp;
    }

    /** Predicts the class of a sparse vector; see {@link #predict(SparseArray, double[])}. */
    public int predict(SparseArray x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class of a sparse vector; absent entries are zeros.
     *
     * <p>BERNOULLI models densify the input so absent features contribute
     * {@code log(1 - condprob)} exactly as in {@link #predict(double[], double[])}.
     * MULTINOMIAL models are unaffected by zero entries. GENERAL models only evaluate the
     * entries that are present (NOTE(review): absent features contribute nothing, unlike the
     * dense path's {@code logp(0)} - confirm this is intended).
     *
     * @param x the sparse input vector
     * @param posteriori optional output array of length {@code k} receiving posteriori probabilities
     * @return the predicted class, or -1 when a MULTINOMIAL/BERNOULLI input has no entry {@code > 0}
     * @throws IllegalArgumentException if {@code posteriori} has the wrong size or an index is out of range
     */
    public int predict(SparseArray x, double[] posteriori) {
        checkPosterioriSize(posteriori);
        if (model == Model.BERNOULLI) {
            return predict(toDense(x), posteriori);
        }

        boolean hasEvidence = model == Model.GENERAL;
        double[] scores = (posteriori != null) ? posteriori : new double[k];
        for (int c = 0; c < k; c++) {
            double logp = Math.log(priori[c]);
            Iterator<SparseArray.Entry> it = x.iterator();
            while (it.hasNext()) {
                SparseArray.Entry e = it.next();
                if (model == Model.GENERAL) {
                    logp += prob[c][e.i].logp(e.x);
                } else if (e.x > 0.0d) {
                    logp += e.x * Math.log(condprob[c][e.i]);
                    hasEvidence = true;
                }
            }
            scores[c] = logp;
        }
        return selectClass(scores, hasEvidence, posteriori != null);
    }

    /**
     * Returns the index of the largest score and optionally turns the scores into normalized
     * posteriori probabilities in place.
     *
     * @param scores per-class log scores (length {@code k})
     * @param hasEvidence when false, no class is selected and the scores are left untouched
     * @param normalize whether to convert {@code scores} to probabilities
     * @return the selected class, or -1 when {@code hasEvidence} is false
     */
    private int selectClass(double[] scores, boolean hasEvidence, boolean normalize) {
        if (!hasEvidence) {
            return -1;
        }
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < k; c++) {
            if (scores[c] > max) {
                max = scores[c];
                best = c;
            }
        }
        if (normalize) {
            // Subtract the maximum before exponentiating to avoid underflow.
            double sum = 0.0d;
            for (int c = 0; c < k; c++) {
                scores[c] = Math.exp(scores[c] - max);
                sum += scores[c];
            }
            for (int c = 0; c < k; c++) {
                scores[c] /= sum;
            }
        }
        return best;
    }

    /** Expands a sparse vector into a dense array of length {@code p}. */
    private double[] toDense(SparseArray x) {
        double[] dense = new double[p];
        Iterator<SparseArray.Entry> it = x.iterator();
        while (it.hasNext()) {
            SparseArray.Entry e = it.next();
            checkIndex(e.i);
            dense[e.i] = e.x;
        }
        return dense;
    }

    /** Throws if this model cannot be trained incrementally. */
    private void requireOnlineModel() {
        if (model == Model.GENERAL) {
            throw new UnsupportedOperationException("General-mode Naive Bayes classifier doesn't support online learning.");
        }
    }

    /** Throws if {@code y} is not a valid class label. */
    private void checkLabel(int y) {
        if (y < 0 || y >= k) {
            throw new IllegalArgumentException("Invalid class label: " + y);
        }
    }

    /** Throws if {@code index} is not a valid feature index. */
    private void checkIndex(int index) {
        if (index < 0 || index >= p) {
            throw new IllegalArgumentException("Invalid feature index: " + index);
        }
    }

    /** Throws if a non-null posteriori output array does not have length {@code k}. */
    private void checkPosterioriSize(double[] posteriori) {
        if (posteriori != null && posteriori.length != k) {
            throw new IllegalArgumentException(String.format(
                    "Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }
    }

    /**
     * Validates priori probabilities: each must lie strictly between 0 and 1 (NaN rejected)
     * and their sum must be within {@link #PRIORI_SUM_TOLERANCE} of one.
     */
    private static void checkPriori(double[] priori) {
        double sum = 0.0d;
        for (double prob : priori) {
            if (!(prob > 0.0d && prob < 1.0d)) {
                throw new IllegalArgumentException("Invalid priori probability: " + prob);
            }
            sum += prob;
        }
        if (Math.abs(sum - 1.0d) > PRIORI_SUM_TOLERANCE) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }
    }

    /** Validates the add-k smoothing parameter (non-negative; NaN rejected). */
    private static void checkSmooth(double sigma) {
        if (!(sigma >= 0.0d)) {
            throw new IllegalArgumentException("Invalid add-k smoothing parameter: " + sigma);
        }
    }
}
