package smile.classification;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.Callable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.Math;
import smile.util.MulticoreExecutor;

/* loaded from: input_file:smile/classification/Maxent.class */
/**
 * Maximum entropy classifier over sparse binary features.
 *
 * <p>Each sample is an {@code int[]} listing the indices (in {@code [0, p)}) of the features that
 * are present. Every listed feature contributes its weight once, and each class additionally has
 * a bias term. With two classes the model is a logistic regression stored as a single weight
 * vector {@code w} of length {@code p + 1}. With more classes a softmax model is trained and one
 * row of {@code p + 1} weights per class is stored in {@code W}. Feature weights (not biases) are
 * penalized by {@code 0.5 * lambda * ||w||^2}.
 *
 * <p>Class labels must be the consecutive integers {@code 0 .. k-1}.
 */
public class Maxent implements SoftClassifier<int[]>, Serializable {
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger(Maxent.class);

    /**
     * Objectives over fewer samples than this (or with fewer than two worker threads) are
     * evaluated serially.
     */
    private static final int PARALLEL_THRESHOLD = 1000;

    /** Smallest number of samples handed to a single parallel task. */
    private static final int MIN_CHUNK_SIZE = 100;

    /** History size passed as the second argument of the optimizer {@code Math.min}. */
    private static final int OPTIMIZER_MEMORY = 5;

    /** Dimensionality of the feature space; feature indices lie in {@code [0, p)}. */
    private int p;
    /** Number of classes. */
    private int k;
    /** Negated optimum of the penalized objective; 0 if the optimizer threw. */
    private double L;
    /** Binary model: {@code p} feature weights followed by the bias; null for multi-class models. */
    private double[] w;
    /** Multi-class model: one row of {@code p} weights plus bias per class; null for binary models. */
    private double[][] W;

    /**
     * Penalized negative log-likelihood of the binary (logistic) model.
     *
     * <p>The parameter vector holds {@code p} feature weights followed by the bias, and labels are
     * expected to be 0 or 1.
     */
    public static class BinaryObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        double lambda;

        /** Computes the unpenalized loss of samples {@code [start, end)}. */
        class FTask implements Callable<Double> {
            final double[] w;
            final int start;
            final int end;

            FTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                return loss(w, start, end);
            }
        }

        /**
         * Computes the unpenalized loss and gradient of samples {@code [start, end)}. The result
         * has {@code w.length + 1} slots: the partial gradient followed by the partial loss.
         */
        class GTask implements Callable<double[]> {
            final double[] w;
            final int start;
            final int end;

            GTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double[] partial = new double[w.length + 1];
                partial[w.length] = lossAndGradient(w, start, end, partial);
                return partial;
            }
        }

        BinaryObjectiveFunction(int[][] x, int[] y, double lambda) {
            this.x = x;
            this.y = y;
            this.lambda = lambda;
        }

        /** Returns the unpenalized negative log-likelihood of samples {@code [start, end)}. */
        private double loss(double[] w, int start, int end) {
            double sum = 0.0;
            for (int i = start; i < end; i++) {
                double wx = Maxent.dot(x[i], w);
                sum += Maxent.log1pe(wx) - y[i] * wx;
            }
            return sum;
        }

        /**
         * Returns the unpenalized loss of samples {@code [start, end)} and accumulates its gradient
         * into {@code g}, which must have at least {@code w.length} slots.
         */
        private double lossAndGradient(double[] w, int start, int end, double[] g) {
            int bias = w.length - 1;
            double sum = 0.0;
            for (int i = start; i < end; i++) {
                double wx = Maxent.dot(x[i], w);
                sum += Maxent.log1pe(wx) - y[i] * wx;
                double err = y[i] - Math.logistic(wx);
                // Features are binary: d(w.x)/dw_j == 1 for every listed index j, so the update
                // is the residual itself (not scaled by the index j).
                for (int j : x[i]) {
                    g[j] -= err;
                }
                g[bias] -= err;
            }
            return sum;
        }

        /** L2 penalty over the feature weights; the bias (last slot) is not penalized. */
        private double penalty(double[] w) {
            if (lambda == 0.0) {
                return 0.0;
            }
            double ss = 0.0;
            for (int j = 0; j < w.length - 1; j++) {
                ss += w[j] * w[j];
            }
            return 0.5 * lambda * ss;
        }

        @Override
        public double f(double[] w) {
            int n = x.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            double value;
            if (n < PARALLEL_THRESHOLD || threads < 2) {
                value = loss(w, 0, n);
            } else {
                int[] bounds = chunkBounds(n, threads);
                ArrayList<Callable<Double>> tasks = new ArrayList<>(bounds.length - 1);
                for (int t = 0; t + 1 < bounds.length; t++) {
                    tasks.add(new FTask(w, bounds[t], bounds[t + 1]));
                }
                value = 0.0;
                try {
                    for (double part : MulticoreExecutor.run(tasks)) {
                        value += part;
                    }
                } catch (Exception e) {
                    handleParallelFailure(e);
                    value = loss(w, 0, n);
                }
            }
            return value + penalty(w);
        }

        @Override
        public double f(double[] w, double[] g) {
            int n = x.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            Arrays.fill(g, 0.0);
            double value;
            if (n < PARALLEL_THRESHOLD || threads < 2) {
                value = lossAndGradient(w, 0, n, g);
            } else {
                int[] bounds = chunkBounds(n, threads);
                ArrayList<Callable<double[]>> tasks = new ArrayList<>(bounds.length - 1);
                for (int t = 0; t + 1 < bounds.length; t++) {
                    tasks.add(new GTask(w, bounds[t], bounds[t + 1]));
                }
                value = 0.0;
                try {
                    for (double[] partial : MulticoreExecutor.run(tasks)) {
                        value += partial[w.length];
                        for (int j = 0; j < w.length; j++) {
                            g[j] += partial[j];
                        }
                    }
                } catch (Exception e) {
                    handleParallelFailure(e);
                    // Discard anything accumulated before the failure and redo the work serially.
                    Arrays.fill(g, 0.0);
                    value = lossAndGradient(w, 0, n, g);
                }
            }
            value += penalty(w);
            if (lambda != 0.0) {
                for (int j = 0; j < w.length - 1; j++) {
                    g[j] += lambda * w[j];
                }
            }
            return value;
        }
    }

    /**
     * Penalized negative log-likelihood of the multi-class (softmax) model.
     *
     * <p>The parameter vector is {@code k} consecutive rows of {@code p + 1} values, each holding
     * {@code p} feature weights followed by that class's bias.
     */
    public static class MultiClassObjectiveFunction implements DifferentiableMultivariateFunction {
        int[][] x;
        int[] y;
        int k;
        int p;
        double lambda;

        /** Computes the unpenalized loss of samples {@code [start, end)}. */
        class FTask implements Callable<Double> {
            final double[] w;
            final int start;
            final int end;

            FTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public Double call() {
                return loss(w, start, end);
            }
        }

        /**
         * Computes the unpenalized loss and gradient of samples {@code [start, end)}. The result
         * has {@code w.length + 1} slots: the partial gradient followed by the partial loss.
         */
        class GTask implements Callable<double[]> {
            final double[] w;
            final int start;
            final int end;

            GTask(double[] w, int start, int end) {
                this.w = w;
                this.start = start;
                this.end = end;
            }

            @Override
            public double[] call() {
                double[] partial = new double[w.length + 1];
                partial[w.length] = lossAndGradient(w, start, end, partial);
                return partial;
            }
        }

        MultiClassObjectiveFunction(int[][] x, int[] y, int k, int p, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.p = p;
            this.lambda = lambda;
        }

        /** Returns the unpenalized negative log-likelihood of samples {@code [start, end)}. */
        private double loss(double[] w, int start, int end) {
            double[] prob = new double[k];
            double sum = 0.0;
            for (int i = start; i < end; i++) {
                for (int c = 0; c < k; c++) {
                    prob[c] = Maxent.dot(x[i], w, c, p);
                }
                Maxent.softmax(prob);
                sum -= Maxent.log(prob[y[i]]);
            }
            return sum;
        }

        /**
         * Returns the unpenalized loss of samples {@code [start, end)} and accumulates its gradient
         * into {@code g}, which must have at least {@code w.length} slots.
         */
        private double lossAndGradient(double[] w, int start, int end, double[] g) {
            double[] prob = new double[k];
            double sum = 0.0;
            for (int i = start; i < end; i++) {
                for (int c = 0; c < k; c++) {
                    prob[c] = Maxent.dot(x[i], w, c, p);
                }
                Maxent.softmax(prob);
                sum -= Maxent.log(prob[y[i]]);
                for (int c = 0; c < k; c++) {
                    double err = (y[i] == c ? 1.0 : 0.0) - prob[c];
                    int offset = c * (p + 1);
                    for (int j : x[i]) {
                        g[offset + j] -= err;
                    }
                    g[offset + p] -= err;
                }
            }
            return sum;
        }

        /** L2 penalty over every class's feature weights; biases are not penalized. */
        private double penalty(double[] w) {
            if (lambda == 0.0) {
                return 0.0;
            }
            double ss = 0.0;
            for (int c = 0; c < k; c++) {
                int offset = c * (p + 1);
                for (int j = 0; j < p; j++) {
                    ss += w[offset + j] * w[offset + j];
                }
            }
            return 0.5 * lambda * ss;
        }

        @Override
        public double f(double[] w) {
            int n = x.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            double value;
            if (n < PARALLEL_THRESHOLD || threads < 2) {
                value = loss(w, 0, n);
            } else {
                int[] bounds = chunkBounds(n, threads);
                ArrayList<Callable<Double>> tasks = new ArrayList<>(bounds.length - 1);
                for (int t = 0; t + 1 < bounds.length; t++) {
                    tasks.add(new FTask(w, bounds[t], bounds[t + 1]));
                }
                value = 0.0;
                try {
                    for (double part : MulticoreExecutor.run(tasks)) {
                        value += part;
                    }
                } catch (Exception e) {
                    handleParallelFailure(e);
                    value = loss(w, 0, n);
                }
            }
            return value + penalty(w);
        }

        @Override
        public double f(double[] w, double[] g) {
            int n = x.length;
            int threads = MulticoreExecutor.getThreadPoolSize();
            Arrays.fill(g, 0.0);
            double value;
            if (n < PARALLEL_THRESHOLD || threads < 2) {
                value = lossAndGradient(w, 0, n, g);
            } else {
                int[] bounds = chunkBounds(n, threads);
                ArrayList<Callable<double[]>> tasks = new ArrayList<>(bounds.length - 1);
                for (int t = 0; t + 1 < bounds.length; t++) {
                    tasks.add(new GTask(w, bounds[t], bounds[t + 1]));
                }
                value = 0.0;
                try {
                    for (double[] partial : MulticoreExecutor.run(tasks)) {
                        value += partial[w.length];
                        for (int j = 0; j < w.length; j++) {
                            g[j] += partial[j];
                        }
                    }
                } catch (Exception e) {
                    handleParallelFailure(e);
                    // Discard anything accumulated before the failure and redo the work serially.
                    Arrays.fill(g, 0.0);
                    value = lossAndGradient(w, 0, n, g);
                }
            }
            value += penalty(w);
            if (lambda != 0.0) {
                for (int c = 0; c < k; c++) {
                    int offset = c * (p + 1);
                    for (int j = 0; j < p; j++) {
                        g[offset + j] += lambda * w[offset + j];
                    }
                }
            }
            return value;
        }
    }

    /**
     * Builder-style trainer for {@link Maxent}. Defaults: no regularization, tolerance 1E-5 and
     * at most 500 iterations. Values are validated when {@link #train} builds the model.
     */
    public static class Trainer extends ClassifierTrainer<int[]> {
        private int p;
        private double lambda = 0.0d;
        private double tol = 1.0E-5d;
        private int maxIter = 500;

        /**
         * @param p dimensionality of the feature space
         * @throws IllegalArgumentException if {@code p} is negative
         */
        public Trainer(int p) {
            if (p < 0) {
                throw new IllegalArgumentException("Invalid dimension: " + p);
            }
            this.p = p;
        }

        /** Sets the L2 regularization factor applied to feature weights. */
        public Trainer setRegularizationFactor(double lambda) {
            this.lambda = lambda;
            return this;
        }

        /** Sets the convergence tolerance passed to the optimizer. */
        public Trainer setTolerance(double tol) {
            this.tol = tol;
            return this;
        }

        /** Sets the maximum number of optimizer iterations. */
        public Trainer setMaxNumIteration(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        @Override
        public Maxent train(int[][] x, int[] y) {
            return new Maxent(p, x, y, lambda, tol, maxIter);
        }
    }

    /** Trains a model with regularization factor 0.1, tolerance 1E-5 and 500 iterations. */
    public Maxent(int p, int[][] x, int[] y) {
        this(p, x, y, 0.1d);
    }

    /** Trains a model with the given regularization factor, tolerance 1E-5 and 500 iterations. */
    public Maxent(int p, int[][] x, int[] y, double lambda) {
        this(p, x, y, lambda, 1.0E-5d, 500);
    }

    /**
     * Trains a maximum entropy model.
     *
     * <p>If the optimizer throws, the error is logged and the model keeps whatever weights the
     * optimizer left in place (initially all zero), with {@link #loglikelihood()} equal to 0.
     *
     * @param p dimensionality of the feature space; feature indices must lie in {@code [0, p)}
     * @param x samples, each the list of indices of its present features
     * @param y class labels, which must be exactly {@code 0 .. k-1} with {@code k >= 2}
     * @param lambda non-negative L2 regularization factor
     * @param tol positive convergence tolerance for the optimizer
     * @param maxIter positive maximum number of optimizer iterations
     * @throws IllegalArgumentException on mismatched sizes, invalid parameters, negative or
     *     non-consecutive labels, or fewer than two classes
     */
    public Maxent(int p, int[][] x, int[] y, double lambda, double tol, int maxIter) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (p < 0) {
            throw new IllegalArgumentException("Invalid dimension: " + p);
        }
        if (lambda < 0.0d) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0d) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        this.p = p;
        this.k = countClasses(y);
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (this.k == 2) {
            trainBinary(x, y, lambda, tol, maxIter);
        } else {
            trainMultiClass(x, y, lambda, tol, maxIter);
        }
    }

    /**
     * Validates that the distinct labels in {@code y} are exactly {@code 0 .. k-1} and returns k.
     *
     * @throws IllegalArgumentException on a negative label or a gap (including a missing 0)
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels are distinct and sorted, so each must be exactly one more than its predecessor.
            int expected = (i == 0) ? 0 : labels[i - 1] + 1;
            if (labels[i] != expected) {
                throw new IllegalArgumentException("Missing class: " + expected);
            }
        }
        return labels.length;
    }

    /** Fits the two-class logistic model into {@code w}. */
    private void trainBinary(int[][] x, int[] y, double lambda, double tol, int maxIter) {
        BinaryObjectiveFunction func = new BinaryObjectiveFunction(x, y, lambda);
        this.w = new double[p + 1];
        this.L = 0.0d;
        try {
            this.L = -Math.min(func, OPTIMIZER_MEMORY, this.w, tol, maxIter);
        } catch (Exception e) {
            logger.error("Failed to minimize binary objective function of Maximum Entropy Classifier", e);
        }
    }

    /** Fits the softmax model, then unpacks the flat parameter vector into per-class rows of {@code W}. */
    private void trainMultiClass(int[][] x, int[] y, double lambda, double tol, int maxIter) {
        MultiClassObjectiveFunction func = new MultiClassObjectiveFunction(x, y, k, p, lambda);
        int stride = p + 1;
        double[] flat = new double[k * stride];
        this.L = 0.0d;
        try {
            this.L = -Math.min(func, OPTIMIZER_MEMORY, flat, tol, maxIter);
        } catch (Exception e) {
            logger.error("Failed to minimize multi-class objective function of Maximum Entropy Classifier", e);
        }
        this.W = new double[k][];
        for (int c = 0; c < k; c++) {
            this.W[c] = Arrays.copyOfRange(flat, c * stride, (c + 1) * stride);
        }
        this.w = null;
    }

    /**
     * Splits the row range {@code [0, n)} into contiguous chunks for parallel evaluation. Every
     * chunk except the last holds {@code max(n / parts, MIN_CHUNK_SIZE)} rows and the last one
     * absorbs the remainder. At most {@code parts} chunks are produced and none extends past
     * {@code n}, so no task indexes beyond the data.
     *
     * @return boundaries {@code b} such that chunk {@code t} is {@code [b[t], b[t + 1])}
     */
    private static int[] chunkBounds(int n, int parts) {
        if (parts < 1) {
            parts = 1;
        }
        int size = n / parts;
        if (size < MIN_CHUNK_SIZE) {
            size = MIN_CHUNK_SIZE;
        }
        int count = n / size;
        if (count > parts) {
            count = parts;
        }
        if (count < 1) {
            count = 1;
        }
        int[] bounds = new int[count + 1];
        for (int t = 0; t < count; t++) {
            bounds[t] = t * size;
        }
        bounds[count] = n;
        return bounds;
    }

    /** Records a failed parallel evaluation before the caller recomputes serially. */
    private static void handleParallelFailure(Exception e) {
        if (e instanceof InterruptedException) {
            // Preserve the interrupt for the caller; the objective is still computed serially.
            Thread.currentThread().interrupt();
        }
        logger.warn("Parallel evaluation of the Maxent objective failed; falling back to serial computation", e);
    }

    /** Returns the dimensionality of the feature space. */
    public int getDimension() {
        return this.p;
    }

    /** Returns {@code log(1 + e^x)}, approximated by {@code x} when {@code x > 15}. */
    public static double log1pe(double x) {
        return x > 15.0d ? x : Math.log1p(Math.exp(x));
    }

    /** Natural log clamped to about {@code log(1e-300)} for tiny arguments, avoiding {@code -Infinity}. */
    public static double log(double x) {
        return x < 1.0E-300d ? -690.7755d : Math.log(x);
    }

    /** Replaces {@code scores} in place with their softmax, shifting by the maximum for stability. */
    public static void softmax(double[] scores) {
        double max = Double.NEGATIVE_INFINITY;
        for (double s : scores) {
            if (s > max) {
                max = s;
            }
        }
        double sum = 0.0d;
        for (int c = 0; c < scores.length; c++) {
            scores[c] = Math.exp(scores[c] - max);
            sum += scores[c];
        }
        for (int c = 0; c < scores.length; c++) {
            scores[c] /= sum;
        }
    }

    /** Returns the sum of {@code w[j]} over the feature indices {@code j} in {@code x}, plus the bias in the last slot of {@code w}. */
    public static double dot(int[] x, double[] w) {
        double sum = w[w.length - 1];
        for (int j : x) {
            sum += w[j];
        }
        return sum;
    }

    /**
     * Same as {@link #dot(int[], double[])} but over class {@code c}'s row of a flat parameter
     * vector whose rows have {@code p} weights followed by a bias.
     */
    public static double dot(int[] x, double[] w, int c, int p) {
        int offset = c * (p + 1);
        double sum = w[offset + p];
        for (int j : x) {
            sum += w[offset + j];
        }
        return sum;
    }

    /** Returns the log-likelihood recorded at training time (negated optimum of the penalized objective). */
    public double loglikelihood() {
        return this.L;
    }

    @Override
    public int predict(int[] x) {
        return predict(x, (double[]) null);
    }

    /**
     * Predicts the class of {@code x} and optionally fills {@code posteriori} with class
     * probabilities.
     *
     * @throws IllegalArgumentException if {@code posteriori} is non-null and not of length k
     */
    @Override
    public int predict(int[] x, double[] posteriori) {
        if (posteriori != null && posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, this.k));
        }
        if (this.w != null) {
            double prob = 1.0d / (1.0d + Math.exp(-dot(x, this.w)));
            if (posteriori != null) {
                posteriori[0] = 1.0d - prob;
                posteriori[1] = prob;
            }
            return prob < 0.5d ? 0 : 1;
        }
        int best = -1;
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double score = dot(x, this.W[c]);
            if (score > max) {
                max = score;
                best = c;
            }
            if (posteriori != null) {
                posteriori[c] = score;
            }
        }
        if (posteriori != null) {
            softmax(posteriori);
        }
        return best;
    }
}
