package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.ConstantVector;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.MathTricks;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/classifiers/linear/StochasticMultinomialLogisticRegression.class */
public class StochasticMultinomialLogisticRegression implements Classifier, Parameterized, SimpleWeightVectorModel {
    private static final long serialVersionUID = -492707881682847556L;
    // Maximum number of passes over the training data; setEpochs rejects values <= 0.
    private int epochs;
    // When true, a regularization step that would flip the sign of a weight sets that weight to zero instead.
    private boolean clipping;
    // Regularization strength; multiplied by the current learning rate during training.
    private double regularization;
    // Training stops early once the relative change of the per-epoch objective is below this value.
    private double tolerance;
    // Learning rate handed to the decay schedule as its starting value.
    private double initialLearningRate;
    // Extra parameter forwarded to the three-argument Prior methods (used by ELASTIC and CAUCHY).
    private double alpha;
    // Schedule that maps (epoch, total epochs, initial rate) to the learning rate of an epoch.
    private DecayRate learningRateDecay;
    // Prior placed on the coefficients; supplies the regularization gradient and log density.
    private Prior prior;
    // When true, training evaluates the prior on coefficients rescaled with per-column statistics.
    private boolean standardized;
    // When true, bias terms are updated during training; otherwise they stay at zero.
    private boolean useBias;
    // Number of data points per mini-batch.
    private int miniBatchSize;
    // Learned weight vectors: one per class except the last (null until train is called).
    private Vec[] B;
    // Learned bias terms, parallel to B.
    private double[] biases;

    /* loaded from: input_file:jsat/classifiers/linear/StochasticMultinomialLogisticRegression$Prior.class */
    /**
     * Priors that can be placed on the model coefficients. Each constant
     * provides the log density of a coefficient value ({@code logProb}) and the
     * derivative of that log density with respect to the coefficient
     * ({@code gradientError}), which training uses as the regularization step.
     *
     * <p>In every method {@code d} is the coefficient value, {@code d2} is the
     * variance parameter and {@code d3} is an extra, prior-specific parameter.
     */
    public enum Prior {
        /** Zero-mean Gaussian prior with variance {@code d2}. */
        GAUSSIAN {
            @Override
            protected double gradientError(double d, double d2) {
                return -d / d2;
            }

            @Override
            protected double logProb(double d, double d2) {
                // log N(d | 0, d2) = -0.5*log(2*pi*d2) - d^2 / (2*d2).
                // The quadratic term must divide by the variance so that its
                // derivative equals gradientError (-d / d2).
                return -0.5d * Math.log(2.0d * Math.PI * d2) - (d * d) / (2.0d * d2);
            }
        },
        /** Zero-mean Laplace prior with variance {@code d2}. */
        LAPLACE {
            @Override
            protected double gradientError(double d, double d2) {
                return -Math.sqrt(2.0d) * Math.signum(d) / Math.sqrt(d2);
            }

            @Override
            protected double logProb(double d, double d2) {
                // -sqrt(2)*|d|/sqrt(d2) - 0.5*log(2*d2)
                return -Math.sqrt(2.0d) * Math.abs(d) / Math.sqrt(d2) - 0.5d * Math.log(2.0d * d2);
            }
        },
        /**
         * Convex combination of {@link #LAPLACE} (weight {@code d3}) and
         * {@link #GAUSSIAN} (weight {@code 1 - d3}). Only the three-argument
         * methods are meaningful for this prior.
         */
        ELASTIC {
            @Override
            protected double gradientError(double d, double d2) {
                throw new UnsupportedOperationException();
            }

            @Override
            protected double gradientError(double d, double d2, double d3) {
                return d3 * LAPLACE.gradientError(d, d2) + (1.0d - d3) * GAUSSIAN.gradientError(d, d2);
            }

            @Override
            protected double logProb(double d, double d2) {
                // No mixing weight available: the value is undefined.
                return Double.NaN;
            }

            @Override
            protected double logProb(double d, double d2, double d3) {
                return d3 * LAPLACE.logProb(d, d2) + (1.0d - d3) * GAUSSIAN.logProb(d, d2);
            }
        },
        /**
         * Cauchy prior whose scale is the extra parameter {@code d3}; the
         * variance argument {@code d2} is ignored. Only the three-argument
         * methods are meaningful for this prior.
         */
        CAUCHY {
            @Override
            protected double gradientError(double d, double d2) {
                throw new UnsupportedOperationException();
            }

            @Override
            protected double gradientError(double d, double d2, double d3) {
                return -2.0d * d / (d * d + d3 * d3);
            }

            @Override
            protected double logProb(double d, double d2) {
                // No scale available: the value is undefined.
                return Double.NaN;
            }

            @Override
            protected double logProb(double d, double d2, double d3) {
                return -Math.log(Math.PI) + Math.log(d3) - Math.log(d * d + d3 * d3);
            }
        },
        /** Flat prior: contributes nothing to the gradient or the log density. */
        UNIFORM {
            @Override
            protected double gradientError(double d, double d2) {
                return 0.0d;
            }

            @Override
            protected double logProb(double d, double d2) {
                return 0.0d;
            }
        };

        /**
         * Derivative of the log density with respect to the coefficient.
         *
         * @param d the coefficient value
         * @param d2 the variance parameter of the prior
         * @return the gradient contribution of the prior
         */
        protected abstract double gradientError(double d, double d2);

        /**
         * Same as {@link #gradientError(double, double)} but with an extra
         * prior-specific parameter. The default implementation ignores
         * {@code d3}.
         *
         * @param d the coefficient value
         * @param d2 the variance parameter of the prior
         * @param d3 the extra parameter of the prior
         * @return the gradient contribution of the prior
         */
        protected double gradientError(double d, double d2, double d3) {
            return gradientError(d, d2);
        }

        /**
         * Log density of the prior at the given coefficient value.
         *
         * @param d the coefficient value
         * @param d2 the variance parameter of the prior
         * @return the log density, or NaN when the prior needs the extra parameter
         */
        protected abstract double logProb(double d, double d2);

        /**
         * Same as {@link #logProb(double, double)} but with an extra
         * prior-specific parameter. The default implementation ignores
         * {@code d3}.
         *
         * @param d the coefficient value
         * @param d2 the variance parameter of the prior
         * @param d3 the extra parameter of the prior
         * @return the log density
         */
        protected double logProb(double d, double d2, double d3) {
            return logProb(d, d2);
        }
    }

    /**
     * Creates a new, untrained model.
     *
     * @param d the initial learning rate, validated by {@link #setInitialLearningRate(double)}
     * @param i the maximum number of training epochs, validated by {@link #setEpochs(int)}
     * @param d2 the regularization strength, validated by {@link #setRegularization(double)}
     * @param prior the prior to place on the coefficients
     * @throws IllegalArgumentException if any numeric argument is rejected by its setter
     */
    public StochasticMultinomialLogisticRegression(double d, int i, double d2, Prior prior) {
        // Defaults for the options that have no constructor argument.
        this.miniBatchSize = 1;
        this.useBias = true;
        this.standardized = true;
        this.learningRateDecay = new ExponetialDecay();
        this.alpha = 0.5d;
        this.tolerance = 1.0E-4d;
        this.clipping = true;

        // Go through the setters so the arguments are validated.
        setEpochs(i);
        setRegularization(d2);
        setInitialLearningRate(d);
        setPrior(prior);
    }

    /**
     * Creates a new, untrained model with a regularization strength of
     * {@code 1e-6} and a {@link Prior#GAUSSIAN} prior.
     *
     * @param d the initial learning rate
     * @param i the maximum number of training epochs
     */
    public StochasticMultinomialLogisticRegression(double d, int i) {
        this(d, i, 1e-6, Prior.GAUSSIAN);
    }

    /**
     * Creates a new, untrained model with an initial learning rate of
     * {@code 0.1} and at most {@code 50} training epochs.
     */
    public StochasticMultinomialLogisticRegression() {
        this(0.1, 50);
    }

    protected StochasticMultinomialLogisticRegression(StochasticMultinomialLogisticRegression stochasticMultinomialLogisticRegression) {
        this.clipping = true;
        this.tolerance = 1.0E-4d;
        this.alpha = 0.5d;
        this.learningRateDecay = new ExponetialDecay();
        this.standardized = true;
        this.useBias = true;
        this.miniBatchSize = 1;
        this.epochs = stochasticMultinomialLogisticRegression.epochs;
        this.clipping = stochasticMultinomialLogisticRegression.clipping;
        this.regularization = stochasticMultinomialLogisticRegression.regularization;
        this.tolerance = stochasticMultinomialLogisticRegression.tolerance;
        this.initialLearningRate = stochasticMultinomialLogisticRegression.initialLearningRate;
        this.alpha = stochasticMultinomialLogisticRegression.alpha;
        this.learningRateDecay = stochasticMultinomialLogisticRegression.learningRateDecay;
        this.prior = stochasticMultinomialLogisticRegression.prior;
        this.standardized = stochasticMultinomialLogisticRegression.standardized;
        if (stochasticMultinomialLogisticRegression.B != null) {
            this.B = new Vec[stochasticMultinomialLogisticRegression.B.length];
            for (int i = 0; i < stochasticMultinomialLogisticRegression.B.length; i++) {
                this.B[i] = stochasticMultinomialLogisticRegression.B[i].mo46clone();
            }
        }
        if (stochasticMultinomialLogisticRegression.biases != null) {
            this.biases = Arrays.copyOf(stochasticMultinomialLogisticRegression.biases, stochasticMultinomialLogisticRegression.biases.length);
        }
    }

    /**
     * Sets whether bias terms are updated during training.
     *
     * @param z {@code true} to learn bias terms
     */
    public void setUseBias(boolean z) {
        useBias = z;
    }

    /**
     * @return {@code true} if bias terms are updated during training
     */
    public boolean isUseBias() {
        return useBias;
    }

    /**
     * Sets the maximum number of passes over the training data.
     *
     * @param i the number of epochs, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setEpochs(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Number of epochs must be positive");
        }
        epochs = i;
    }

    /**
     * @return the maximum number of passes over the training data
     */
    public int getEpochs() {
        return epochs;
    }

    /**
     * Sets the extra parameter that is forwarded to the prior.
     *
     * @param d the extra parameter; must be finite and not negative
     * @throws IllegalArgumentException if {@code d} is negative, NaN or infinite
     */
    public void setAlpha(double d) {
        final boolean acceptable = Double.isFinite(d) && d >= 0.0d;
        if (!acceptable) {
            throw new IllegalArgumentException("Extra parameter must be non negative, not " + d);
        }
        alpha = d;
    }

    /**
     * @return the extra parameter that is forwarded to the prior
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets whether a regularization step that would flip the sign of a weight
     * sets that weight to zero instead.
     *
     * @param z {@code true} to enable clipping
     */
    public void setClipping(boolean z) {
        clipping = z;
    }

    /**
     * @return {@code true} if clipping is enabled
     */
    public boolean isClipping() {
        return clipping;
    }

    /**
     * Sets the learning rate that the decay schedule starts from.
     *
     * @param d the initial learning rate; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is not positive, NaN or infinite
     */
    public void setInitialLearningRate(double d) {
        final boolean acceptable = Double.isFinite(d) && d > 0.0d;
        if (!acceptable) {
            throw new IllegalArgumentException("Learning rate must be a positive constant, not " + d);
        }
        initialLearningRate = d;
    }

    /**
     * @return the learning rate that the decay schedule starts from
     */
    public double getInitialLearningRate() {
        return initialLearningRate;
    }

    /**
     * Sets the schedule used to derive each epoch's learning rate from the
     * initial learning rate. The argument is stored as-is, without validation.
     *
     * @param decayRate the decay schedule to use
     */
    public void setLearningRateDecay(DecayRate decayRate) {
        learningRateDecay = decayRate;
    }

    /**
     * @return the schedule used to derive each epoch's learning rate
     */
    public DecayRate getLearningRateDecay() {
        return learningRateDecay;
    }

    /**
     * Sets the regularization strength applied to the prior.
     *
     * @param d the regularization strength; must be finite and not negative
     * @throws IllegalArgumentException if {@code d} is negative, NaN or infinite
     */
    public void setRegularization(double d) {
        final boolean acceptable = Double.isFinite(d) && d >= 0.0d;
        if (!acceptable) {
            throw new IllegalArgumentException("Regualrization must be a non negative constant, not " + d);
        }
        regularization = d;
    }

    /**
     * @return the regularization strength applied to the prior
     */
    public double getRegularization() {
        return regularization;
    }

    /**
     * Sets the prior placed on the coefficients. The argument is stored as-is,
     * without validation.
     *
     * @param prior the prior to use
     */
    public void setPrior(Prior prior) {
        Prior chosen = prior;
        this.prior = chosen;
    }

    /**
     * @return the prior placed on the coefficients
     */
    public Prior getPrior() {
        return prior;
    }

    /**
     * Sets the early-stopping threshold: training returns once the relative
     * change of the per-epoch objective is below this value. The argument is
     * stored as-is, without validation.
     *
     * @param d the tolerance
     */
    public void setTolerance(double d) {
        tolerance = d;
    }

    /**
     * @return the early-stopping threshold
     */
    public double getTolerance() {
        return tolerance;
    }

    /**
     * Sets whether training evaluates the prior on coefficients rescaled with
     * per-column statistics of the training data.
     *
     * @param z {@code true} to enable the rescaling
     */
    public void setStandardized(boolean z) {
        standardized = z;
    }

    /**
     * @return {@code true} if the prior is evaluated on rescaled coefficients
     */
    public boolean isStandardized() {
        return standardized;
    }

    /**
     * Sets the number of data points processed per mini-batch.
     *
     * <p>Training advances through the shuffled data in steps of this size, so
     * a value of zero would never advance and a negative value would step
     * backwards; both are rejected here.
     *
     * @param i the mini-batch size, at least 1
     * @throws IllegalArgumentException if {@code i} is not positive
     */
    public void setMiniBatchSize(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("Mini batch size must be positive, not " + i);
        }
        this.miniBatchSize = i;
    }

    /**
     * @return the number of data points processed per mini-batch
     */
    public int getMiniBatchSize() {
        return miniBatchSize;
    }

    /**
     * Returns the weight vector at the given index. The index one past the
     * stored vectors yields a constant all-zero vector with the same length as
     * the first stored vector.
     *
     * @param i the index of the weight vector
     * @return the stored vector, or a constant zero vector for the final index
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i != B.length) {
            return B[i];
        }
        return new ConstantVector(0.0d, B[0].length());
    }

    /**
     * Returns the bias term at the given index. The index one past the stored
     * biases yields zero.
     *
     * @param i the index of the bias term
     * @return the stored bias, or {@code 0.0} for the final index
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        return i == biases.length ? 0.0d : biases[i];
    }

    /**
     * @return the number of stored weight vectors plus one
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1 + B.length;
    }

    /**
     * Scores the data point against every stored weight vector (dot product
     * plus bias), appends a fixed entry of {@code 1.0} for the remaining
     * class, and passes the scores through {@code MathTricks.softmax}.
     *
     * @param dataPoint the data point to classify
     * @return the softmax-normalised scores wrapped as categorical results
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (B == null) {
            throw new UntrainedModelException("Model has not yet been trained");
        }
        final Vec x = dataPoint.getNumericalValues();
        final int stored = B.length;
        final double[] scores = new double[stored + 1];
        for (int c = 0; c < stored; c++) {
            scores[c] = x.dot(B[c]) + biases[c];
        }
        // The class without its own weight vector gets a fixed score.
        scores[stored] = 1.0d;
        MathTricks.softmax(scores, false);
        return new CategoricalResults(scores);
    }

    /**
     * Trains the model on the given data set. The boolean argument is ignored;
     * this simply delegates to {@link #train(ClassificationDataSet)}.
     *
     * @param classificationDataSet the data to fit
     * @param z unused
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet, boolean z) {
        this.train(classificationDataSet);
    }

    /**
     * Fits the model to the given data set with mini-batch stochastic gradient
     * updates. The prior's regularization step is applied lazily: a feature is
     * only brought up to date when a data point that touches it is visited,
     * and all features are brought up to date at the end of each epoch.
     *
     * <p>Any previously learned weights and biases are discarded. Training
     * runs for at most {@code epochs} passes and returns early once the
     * relative change of the per-epoch objective drops below
     * {@code tolerance}.
     *
     * @param classificationDataSet the data to fit; must have at least one numeric attribute
     * @throws FailedToFitException if the data set has no numeric attributes
     */
    @Override // jsat.classifiers.Classifier
    public void train(ClassificationDataSet classificationDataSet) {
        double d;
        double d2;
        double logProb;
        int sampleSize = classificationDataSet.getSampleSize();
        // Sample count as a double; divides the lazy regularization step below.
        double d3 = sampleSize;
        int numNumericalVars = classificationDataSet.getNumNumericalVars();
        if (numNumericalVars < 1) {
            throw new FailedToFitException("Data set has no numeric attributes to train on");
        }
        // One weight vector and one bias per class except the last.
        this.B = new Vec[classificationDataSet.getClassSize() - 1];
        this.biases = new double[this.B.length];
        for (int i = 0; i < this.B.length; i++) {
            this.B[i] = new DenseVector(numNumericalVars);
        }
        // Visiting order of the data points; reshuffled at the start of every epoch.
        // NOTE(review): presumably addRange fills the list with 0..sampleSize-1 -- confirm against ListUtils.
        IntList intList = new IntList(sampleSize);
        ListUtils.addRange(intList, 0, sampleSize, 1);
        Vec vec = null;
        Vec vec2 = null;
        if (this.standardized) {
            // After these steps vec holds column[0] / sqrt(column[1]) and vec2 holds 1 / sqrt(column[1]).
            // NOTE(review): the method name suggests column[0] = means and column[1] = variances,
            // i.e. vec = mean/stddev and vec2 = 1/stddev -- confirm.
            Vec[] columnMeanVariance = classificationDataSet.getColumnMeanVariance();
            vec = columnMeanVariance[0];
            vec2 = columnMeanVariance[1];
            vec2.applyFunction(Math::sqrt);
            vec.pairwiseDivide(vec2);
            vec2.applyFunction(d4 -> {
                return 1.0d / d4;
            });
        }
        // Per-class scores of the current data point, overwritten in place by softmax.
        double[] dArr = new double[this.B.length];
        // Per-feature value of the step counter i2 at which regularization was last applied (0 = never assigned).
        int[] iArr = new int[numNumericalVars];
        // Number of mini-batches processed so far; keeps counting across epochs.
        int i2 = 0;
        // Objective value of the previous epoch, used for the convergence test.
        double d5 = Double.POSITIVE_INFINITY;
        for (int i3 = 0; i3 < this.epochs; i3++) {
            Collections.shuffle(intList);
            // Accumulates log(softmax output) terms over this epoch.
            double d6 = 0.0d;
            double rate = this.learningRateDecay.rate(i3, this.epochs, this.initialLearningRate);
            // Regularization strength scaled by this epoch's learning rate.
            double d7 = this.regularization * rate;
            int i4 = 0;
            while (true) {
                // i5 is the position in intList where the current mini-batch starts.
                int i5 = i4;
                if (i5 >= intList.size()) {
                    break;
                }
                // The last batch of an epoch may be smaller than miniBatchSize.
                int min = Math.min(this.miniBatchSize, intList.size() - i5);
                // Each point contributes 1/batchSize of its gradient.
                double d8 = 1.0d / min;
                for (int i6 = 0; i6 < min; i6++) {
                    int intValue = intList.get(i5 + i6).intValue();
                    int dataPointCategory = classificationDataSet.getDataPointCategory(intValue);
                    Vec numericalValues = classificationDataSet.getDataPoint(intValue).getNumericalValues();
                    // Linear score per stored class, then softmax in place.
                    for (int i7 = 0; i7 < this.B.length; i7++) {
                        dArr[i7] = numericalValues.dot(this.B[i7]) + this.biases[i7];
                    }
                    MathTricks.softmax(dArr, true);
                    if (this.prior != Prior.UNIFORM) {
                        // Lazily apply the regularization steps that were skipped for the
                        // features visited by this point's iterator.
                        Iterator<IndexValue> it = numericalValues.iterator();
                        while (it.hasNext()) {
                            int index = it.next().getIndex();
                            // Features whose marker is still 0 are skipped here; markers first
                            // become non-zero in the end-of-epoch pass below.
                            if (iArr[index] != 0) {
                                // Step size scaled by (marker - current step), which is <= 0.
                                double d9 = (d7 * (iArr[index] - i2)) / d3;
                                for (Vec vec3 : this.B) {
                                    double d10 = vec3.get(index);
                                    double gradientError = this.standardized ? d10 + (d9 * this.prior.gradientError((d10 * vec2.get(index)) - vec.get(index), 1.0d, this.alpha)) : d10 + (d9 * this.prior.gradientError(d10, 1.0d, this.alpha));
                                    // With clipping on, an update that changes the weight's sign zeroes the weight.
                                    if (!this.clipping || Math.signum(d10) == Math.signum(gradientError)) {
                                        vec3.set(index, gradientError);
                                    } else {
                                        vec3.set(index, 0.0d);
                                    }
                                }
                                iArr[index] = i2;
                            }
                        }
                    }
                    // Gradient step on the weights (and biases) of every stored class.
                    int i8 = 0;
                    while (i8 < this.B.length) {
                        Vec vec4 = this.B[i8];
                        double d11 = dArr[i8];
                        // Infinite logs (softmax output of 0) are left out of the running total.
                        double log = Math.log(d11);
                        if (!Double.isInfinite(log)) {
                            d6 += log;
                        }
                        // (1 if this is the point's class, else 0) minus the softmax output.
                        double d12 = (i8 == dataPointCategory ? 1 : 0) - d11;
                        vec4.mutableAdd(d8 * rate * d12, numericalValues);
                        if (this.useBias) {
                            // Bias gets the same error term (without the feature vector) plus a prior term
                            // evaluated at (bias - 1).
                            double[] dArr2 = this.biases;
                            int i9 = i8;
                            dArr2[i9] = dArr2[i9] + (d8 * rate * d12) + (d7 * this.prior.gradientError(this.biases[i8] - 1.0d, 1.0d, this.alpha));
                        }
                        i8++;
                    }
                }
                i2++;
                i4 = i5 + this.miniBatchSize;
            }
            // Negate the accumulated log terms to form this epoch's objective.
            double d13 = d6 * (-1.0d);
            if (this.prior != Prior.UNIFORM) {
                // End-of-epoch pass: bring every feature up to date and add the prior's
                // log density, weighted by the regularization strength, to the objective.
                for (int i10 = 0; i10 < numNumericalVars; i10++) {
                    if (iArr[i10] - i2 == 0) {
                        // Feature already up to date: only add the prior terms.
                        for (Vec vec5 : this.B) {
                            if (this.standardized) {
                                d = d13;
                                d2 = this.regularization;
                                logProb = this.prior.logProb((vec5.get(i10) * vec2.get(i10)) - vec.get(i10), 1.0d, this.alpha);
                            } else {
                                d = d13;
                                d2 = this.regularization;
                                logProb = this.prior.logProb(vec5.get(i10), 1.0d, this.alpha);
                            }
                            d13 = d + (d2 * logProb);
                        }
                    } else {
                        // Feature is behind: apply the pending regularization first.
                        double d14 = (d7 * (iArr[i10] - i2)) / d3;
                        for (Vec vec6 : this.B) {
                            double d15 = vec6.get(i10);
                            // Weights that are exactly zero are neither updated nor counted in the objective.
                            if (d15 != 0.0d) {
                                double gradientError2 = this.standardized ? d15 + (d14 * this.prior.gradientError((d15 * vec2.get(i10)) - vec.get(i10), 1.0d, this.alpha)) : d15 + (d14 * this.prior.gradientError(d15, 1.0d, this.alpha));
                                if (!this.clipping || Math.signum(d15) == Math.signum(gradientError2)) {
                                    vec6.set(i10, gradientError2);
                                } else {
                                    vec6.set(i10, 0.0d);
                                }
                                d13 = this.standardized ? d13 + (this.regularization * this.prior.logProb((vec6.get(i10) * vec2.get(i10)) - vec.get(i10), 1.0d, this.alpha)) : d13 + (this.regularization * this.prior.logProb(vec6.get(i10), 1.0d, this.alpha));
                            }
                        }
                        iArr[i10] = i2;
                    }
                }
            }
            // Stop once the relative change of the objective is below the tolerance.
            // In the first epoch d5 is +infinity, so (for a finite d13) the ratio is NaN and the test is false.
            if (Math.abs(d5 - d13) / (Math.abs(d5) + Math.abs(d13)) < this.tolerance) {
                return;
            }
            d5 = d13;
        }
    }

    /**
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        final boolean weightedDataSupported = false;
        return weightedDataSupported;
    }

    /**
     * Returns the stored weight vector at the given index. The returned vector
     * is the internal instance, not a copy.
     *
     * @param i the index of the stored weight vector
     * @return the weight vector at that index
     */
    public Vec getCoefficientVector(int i) {
        final Vec coefficients = B[i];
        return coefficients;
    }

    /**
     * Creates a copy of this classifier through the copy constructor.
     *
     * @return a new instance built from this one
     */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Classifier m57clone() {
        final StochasticMultinomialLogisticRegression copy = new StochasticMultinomialLogisticRegression(this);
        return copy;
    }
}
