/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Platt scaling: maps raw classifier scores to probability estimates.
 *
 * <p>A sigmoid {@code P(y = +1 | s) = 1 / (1 + exp(alpha * s + beta))} is fitted to
 * (score, label) pairs by minimizing the cross-entropy against smoothed targets:
 * {@code (N+ + 1) / (N+ + 2)} for positive samples and {@code 1 / (N- + 2)} for the rest,
 * where {@code N+} / {@code N-} are the positive / negative sample counts. The minimization
 * uses Newton's method, with a tiny ridge added to the Hessian diagonal, followed by a
 * backtracking line search. Every exponential is evaluated on the branch where it cannot
 * overflow.
 *
 * <p>Instances are immutable once constructed.
 */
public class PlattScaling implements Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(PlattScaling.class);

    /** Smallest line-search step tried before the search is declared failed. */
    private static final double MIN_STEP = 1.0E-10;
    /** Ridge added to the Hessian diagonal so that it stays strictly positive definite. */
    private static final double SIGMA = 1.0E-12;
    /** Convergence tolerance applied to both gradient components. */
    private static final double EPS = 1.0E-5;
    /** Sufficient-decrease (Armijo) constant used by the line search. */
    private static final double ARMIJO = 1.0E-4;

    /** Slope of the fitted sigmoid, applied to the raw score. */
    private final double alpha;
    /** Intercept of the fitted sigmoid. */
    private final double beta;

    /**
     * Fits the sigmoid with at most 100 Newton iterations.
     *
     * @param scores raw classifier scores, one per sample.
     * @param y      class labels aligned with {@code scores}; values {@code > 0} are positive.
     * @throws IllegalArgumentException if {@code scores} and {@code y} differ in length.
     */
    public PlattScaling(double[] scores, int[] y) {
        this(scores, y, 100);
    }

    /**
     * Fits the sigmoid to the given scores and labels.
     *
     * @param scores   raw classifier scores, one per sample.
     * @param y        class labels aligned with {@code scores}; any value {@code > 0} is
     *                 treated as the positive class, everything else as negative.
     * @param maxIters maximum number of Newton iterations; a warning is logged when the loop
     *                 runs out of iterations without converging.
     * @throws IllegalArgumentException if {@code scores} and {@code y} differ in length.
     */
    public PlattScaling(double[] scores, int[] y, int maxIters) {
        if (scores.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of scores and labels don't match: %d != %d", scores.length, y.length));
        }

        int n = scores.length;
        double prior1 = 0.0; // number of positive samples
        double prior0 = 0.0; // number of negative samples
        for (int label : y) {
            if (label > 0) {
                prior1 += 1.0;
            } else {
                prior0 += 1.0;
            }
        }

        // Smoothed targets instead of hard 0/1 labels, so the fit cannot saturate the sigmoid.
        double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
        double loTarget = 1.0 / (prior0 + 2.0);
        double[] t = new double[n];
        for (int i = 0; i < n; i++) {
            t[i] = y[i] > 0 ? hiTarget : loTarget;
        }

        // Start from a flat sigmoid whose output equals the smoothed positive-class prior.
        double a = 0.0;
        double b = Math.log((prior0 + 1.0) / (prior1 + 1.0));
        double fval = objective(scores, t, a, b);

        int iter;
        for (iter = 0; iter < maxIters; iter++) {
            // Gradient and ridge-regularized Hessian of the objective at (a, b).
            double h11 = SIGMA;
            double h22 = SIGMA;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < n; i++) {
                double fApB = scores[i] * a + b;
                // p = 1 / (1 + exp(fApB)) and q = 1 - p (mathematically), each computed on the
                // branch where exp() cannot overflow.
                double p;
                double q;
                if (fApB >= 0.0) {
                    double e = Math.exp(-fApB);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    double e = Math.exp(fApB);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                double d2 = p * q;
                h11 += scores[i] * scores[i] * d2;
                h22 += d2;
                h21 += scores[i] * d2;
                double d1 = t[i] - p;
                g1 += scores[i] * d1;
                g2 += d1;
            }

            if (Math.abs(g1) < EPS && Math.abs(g2) < EPS) {
                break;
            }

            // Newton direction -inv(H) * g via the closed-form inverse of the 2x2 Hessian.
            double det = h11 * h22 - h21 * h21;
            double dA = -(h22 * g1 - h21 * g2) / det;
            double dB = -(-h21 * g1 + h11 * g2) / det;
            double gd = g1 * dA + g2 * dB;

            // Backtracking line search: halve the step until the sufficient-decrease test passes.
            double stepsize = 1.0;
            while (stepsize >= MIN_STEP) {
                double newA = a + stepsize * dA;
                double newB = b + stepsize * dB;
                double newf = objective(scores, t, newA, newB);
                if (newf < fval + ARMIJO * stepsize * gd) {
                    a = newA;
                    b = newB;
                    fval = newf;
                    break;
                }
                stepsize /= 2.0;
            }

            if (stepsize < MIN_STEP) {
                logger.error("Line search fails.");
                break;
            }
        }

        if (iter >= maxIters) {
            logger.warn("Reaches maximal iterations");
        }

        this.alpha = a;
        this.beta = b;
    }

    /**
     * Returns the cross-entropy between the targets {@code t} and the sigmoid
     * {@code 1 / (1 + exp(a * s + b))} evaluated at every score {@code s}. The expression is
     * split on the sign of {@code a * s + b} so that the exponent passed to exp() is never
     * positive.
     */
    private static double objective(double[] scores, double[] t, double a, double b) {
        double f = 0.0;
        for (int i = 0; i < scores.length; i++) {
            double fApB = scores[i] * a + b;
            if (fApB >= 0.0) {
                f += t[i] * fApB + Math.log(1.0 + Math.exp(-fApB));
            } else {
                f += (t[i] - 1.0) * fApB + Math.log(1.0 + Math.exp(fApB));
            }
        }
        return f;
    }

    /**
     * Returns the estimated probability that a sample with the given raw score belongs to the
     * positive class, i.e. {@code 1 / (1 + exp(alpha * y + beta))}.
     *
     * @param y the raw classifier score (not a class label).
     * @return the probability of the positive class; within {@code [0, 1]} for finite input.
     */
    public double predict(double y) {
        double fApB = y * alpha + beta;
        // Evaluate on the branch where exp() cannot overflow.
        if (fApB >= 0.0) {
            double e = Math.exp(-fApB);
            return e / (1.0 + e);
        }
        return 1.0 / (1.0 + Math.exp(fApB));
    }

    /**
     * Couples pairwise class-probability estimates into a single multiclass distribution.
     *
     * <p>Builds the matrix {@code Q} from {@code r}. It then iteratively updates {@code p},
     * renormalizing after each coordinate step so that {@code sum(p) == 1}, until every
     * {@code (Q p)[t]} is within {@code 0.005 / k} of {@code p' Q p}, or until
     * {@code max(100, k)} sweeps have run. In the latter case a warning is logged. This follows
     * the pairwise-coupling method used by LIBSVM (Wu, Lin and Weng, 2004).
     *
     * <p>Only the off-diagonal entries of {@code r} are read. {@code r[i][j]} is presumably the
     * estimate of {@code P(y = i | y = i or j)}; callers should confirm.
     *
     * @param k the number of classes.
     * @param r pairwise probability estimates, indexed {@code r[i][j]} for {@code i != j}.
     * @param p output array whose first {@code k} elements are overwritten with the coupled
     *          class probabilities.
     * @throws IllegalArgumentException if {@code k < 0}, {@code p} has fewer than {@code k}
     *                                  elements, or {@code k > 1} and {@code r} has fewer than
     *                                  {@code k} rows.
     */
    public static void multiclass(int k, double[][] r, double[] p) {
        if (k < 0) {
            throw new IllegalArgumentException("Invalid number of classes: " + k);
        }
        if (p.length < k) {
            throw new IllegalArgumentException(String.format(
                    "p has %d elements but k = %d", p.length, k));
        }
        // r is only read when there are at least two classes.
        if (k > 1 && r.length < k) {
            throw new IllegalArgumentException(String.format(
                    "r has %d rows but k = %d", r.length, k));
        }

        double[][] q = new double[k][k];
        double[] qp = new double[k];
        double eps = 0.005 / k;

        for (int t = 0; t < k; t++) {
            p[t] = 1.0 / k;
            q[t][t] = 0.0;
            for (int j = 0; j < t; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = q[j][t]; // symmetric: already filled in row j
            }
            for (int j = t + 1; j < k; j++) {
                q[t][t] += r[j][t] * r[j][t];
                q[t][j] = -r[j][t] * r[t][j];
            }
        }

        int maxIter = Math.max(100, k);
        int iter;
        for (iter = 0; iter < maxIter; iter++) {
            // Recompute Qp and p'Qp from scratch each sweep to limit accumulated rounding error.
            double pQp = 0.0;
            for (int t = 0; t < k; t++) {
                qp[t] = 0.0;
                for (int j = 0; j < k; j++) {
                    qp[t] += q[t][j] * p[j];
                }
                pQp += p[t] * qp[t];
            }

            // Explicit comparison (not Math.max) so a NaN error is skipped, as before.
            double maxError = 0.0;
            for (int t = 0; t < k; t++) {
                double error = Math.abs(qp[t] - pQp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < eps) {
                break;
            }

            for (int t = 0; t < k; t++) {
                double diff = (-qp[t] + pQp) / q[t][t];
                p[t] += diff;
                // Update p'Qp and Qp incrementally for the step, then rescale by 1 / (1 + diff)
                // so that p sums to one again.
                pQp = (pQp + diff * (diff * q[t][t] + 2.0 * qp[t])) / (1.0 + diff) / (1.0 + diff);
                for (int j = 0; j < k; j++) {
                    qp[j] = (qp[j] + diff * q[t][j]) / (1.0 + diff);
                    p[j] /= (1.0 + diff);
                }
            }
        }

        if (iter >= maxIter) {
            logger.warn("Reaches maximal iterations");
        }
    }
}

