/*
 * Decompiled with CFR 0.152.
 */
package umontreal.iro.lecuyer.probdist;

import umontreal.iro.lecuyer.functions.MathFunction;
import umontreal.iro.lecuyer.probdist.BetaDist;
import umontreal.iro.lecuyer.probdist.DiscreteDistributionInt;
import umontreal.iro.lecuyer.util.Num;
import umontreal.iro.lecuyer.util.RootFinder;

/**
 * Negative binomial distribution with real parameter {@code gamma > 0} and
 * probability {@code 0 <= p <= 1}. Its mass function is
 * <pre>
 *   P[X = x] = Gamma(gamma + x) / (x! Gamma(gamma)) * p^gamma * (1 - p)^x,   x = 0, 1, 2, ...
 * </pre>
 * with mean {@code gamma (1 - p) / p} and variance {@code gamma (1 - p) / p^2}.
 *
 * <p>When the computed mode lies in {@code [0, MAXN]}, {@link #setParams(double, double)}
 * precomputes tables of the mass function and of the distribution function over the range
 * {@code [xmin, xmax]} of non-negligible values. Outside that range, or when no table was
 * built, values are computed by the static methods.
 */
public class NegativeBinomialDist
extends DiscreteDistributionInt {
    /** Parameter gamma of the distribution. */
    protected double gamma;
    /** Parameter p of the distribution. */
    protected double p;
    /**
     * Largest mode for which {@link #setParams(double, double)} builds the precomputed
     * tables; for larger modes every evaluation is done directly.
     */
    public static double MAXN = 100000.0;

    /** Protected no-argument constructor; leaves the parameters unset. */
    protected NegativeBinomialDist() {
    }

    /**
     * Creates a negative binomial distribution with parameters {@code gamma} and {@code p}.
     *
     * @param gamma parameter gamma, must be {@code > 0}
     * @param p     parameter p, must be in {@code [0, 1]}
     * @throws IllegalArgumentException if {@code p} is not in [0, 1] or {@code gamma <= 0}
     */
    public NegativeBinomialDist(double gamma, double p) {
        this.setParams(gamma, p);
    }

    /**
     * Returns {@code P[X = x]}, read from the precomputed table when {@code x} lies in
     * {@code [xmin, xmax]} and computed directly otherwise.
     */
    @Override
    public double prob(int x) {
        if (x < 0 || this.p == 0.0) {
            return 0.0;
        }
        if (this.p == 1.0) {
            // Degenerate case: all the mass is at 0.
            return x == 0 ? 1.0 : 0.0;
        }
        boolean inTable = this.pdf != null && x >= this.xmin && x <= this.xmax;
        return inTable ? this.pdf[x - this.xmin] : NegativeBinomialDist.prob(this.gamma, this.p, x);
    }

    /**
     * Returns the distribution function {@code F(x) = P[X <= x]}.
     * The table {@code cdf[]} holds {@code F(i)} for {@code i <= xmed} and
     * {@code P[X >= i]} for {@code i > xmed}, so above the median the value is
     * obtained as {@code 1 - P[X >= x + 1]}.
     */
    @Override
    public double cdf(int x) {
        if (x < 0) {
            return 0.0;
        }
        if (this.p >= 1.0) {
            return 1.0;
        }
        if (this.p <= 0.0) {
            return 0.0;
        }
        if (this.cdf == null || x < this.xmin) {
            return NegativeBinomialDist.cdf(this.gamma, this.p, x);
        }
        if (x >= this.xmax) {
            return 1.0;
        }
        if (x <= this.xmed) {
            return this.cdf[x - this.xmin];
        }
        return 1.0 - this.cdf[x + 1 - this.xmin];
    }

    /**
     * Returns the complementary distribution function {@code P[X >= x]}.
     * Uses {@code BetaDist.barF(gamma, x, 15, p)} when no table is available or
     * {@code x > xmax}; otherwise reads the table (see {@link #cdf(int)}).
     */
    @Override
    public double barF(int x) {
        if (x < 1) {
            return 1.0;
        }
        if (this.p >= 1.0) {
            return 0.0;
        }
        if (this.p <= 0.0) {
            return 1.0;
        }
        if (this.cdf == null || x > this.xmax) {
            return BetaDist.barF(this.gamma, x, 15, this.p);
        }
        if (x <= this.xmin) {
            // Mass below xmin was dropped as negligible when the table was built.
            return 1.0;
        }
        if (x > this.xmed) {
            return this.cdf[x - this.xmin];
        }
        return 1.0 - this.cdf[x - 1 - this.xmin];
    }

    /**
     * Returns the inverse distribution function at {@code u}. Uses the static
     * {@link #inverseF(double, double, double)} when no table was built, and the
     * superclass table-based inversion otherwise.
     */
    @Override
    public int inverseFInt(double u) {
        if (this.cdf == null) {
            return NegativeBinomialDist.inverseF(this.gamma, this.p, u);
        }
        return super.inverseFInt(u);
    }

    /** Returns the mean {@code gamma (1 - p) / p}. */
    @Override
    public double getMean() {
        return NegativeBinomialDist.getMean(this.gamma, this.p);
    }

    /** Returns the variance {@code gamma (1 - p) / p^2}. */
    @Override
    public double getVariance() {
        return NegativeBinomialDist.getVariance(this.gamma, this.p);
    }

    /** Returns the standard deviation (square root of the variance). */
    @Override
    public double getStandardDeviation() {
        return NegativeBinomialDist.getStandardDeviation(this.gamma, this.p);
    }

    /**
     * Computes {@code P[X = x]} through its logarithm
     * {@code lnGamma(gamma + x) - ln(x!) - lnGamma(gamma) + gamma ln p + x ln(1 - p)}.
     *
     * @param gamma parameter gamma, must be {@code > 0}
     * @param p     parameter p, must be in {@code [0, 1]}
     * @param x     value at which the mass function is evaluated
     * @return {@code P[X = x]}; 0 when {@code x < 0} or the value underflows
     * @throws IllegalArgumentException if {@code p} is not in [0, 1], {@code gamma <= 0},
     *         or the logarithm is too large to exponentiate
     */
    public static double prob(double gamma, double p, int x) {
        // 1023 ln 2 and -1022 ln 2: exp() limits used for overflow / underflow detection.
        final double MAXEXP = 709.0895657128241;
        final double MINEXP = -708.3964185322641;
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0.0");
        }
        if (x < 0) {
            return 0.0;
        }
        if (p >= 1.0) {
            return x == 0 ? 1.0 : 0.0;
        }
        if (p <= 0.0) {
            return 0.0;
        }
        double y = Num.lnGamma(gamma + (double)x) - (Num.lnFactorial(x) + Num.lnGamma(gamma))
                + gamma * Math.log(p) + (double)x * Math.log1p(-p);
        if (y >= MAXEXP) {
            throw new IllegalArgumentException("term overflow");
        }
        if (y <= MINEXP) {
            return 0.0;
        }
        return Math.exp(y);
    }

    /**
     * Computes {@code F(x) = P[X <= x]}. When the mode (clipped to {@code [0, x]}) is at
     * most 100000, the mass function is summed outward from the mode using the term
     * ratios, stopping each direction at the first term below {@code EPSILON}; otherwise
     * {@code BetaDist.cdf(gamma, x + 1, 15, p)} is used.
     *
     * @param gamma parameter gamma, must be {@code > 0}
     * @param p     parameter p, must be in {@code [0, 1]}
     * @param x     value at which the distribution function is evaluated
     * @return {@code F(x)}, capped at 1
     * @throws IllegalArgumentException if {@code p} is not in [0, 1] or {@code gamma <= 0}
     */
    public static double cdf(double gamma, double p, int x) {
        final double eps = DiscreteDistributionInt.EPSILON;
        final int LIM1 = 100000;
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0.0");
        }
        if (x < 0) {
            return 0.0;
        }
        if (p >= 1.0) {
            return 1.0;
        }
        if (p <= 0.0) {
            return 0.0;
        }
        int mode = 1 + (int)Math.floor((gamma * (1.0 - p) - 1.0) / p);
        if (mode < 0) {
            mode = 0;
        } else if (mode > x) {
            mode = x;
        }
        if (mode > LIM1) {
            return BetaDist.cdf(gamma, (double)x + 1.0, 15, p);
        }
        final double q = 1.0 - p;
        final double termMode = NegativeBinomialDist.prob(gamma, p, mode);
        double sum = termMode;
        // Left of the mode: P(i - 1) = P(i) * i / (q (gamma + i - 1)).
        double term = termMode;
        for (int i = mode; i > 0; --i) {
            term *= (double)i / (q * (gamma + (double)i - 1.0));
            if (term < eps) {
                break;
            }
            sum += term;
        }
        // Right of the mode, up to x: P(i + 1) = P(i) * q (gamma + i) / (i + 1).
        term = termMode;
        for (int i = mode; i < x; ++i) {
            term *= q * (gamma + (double)i) / (double)(i + 1);
            if (term < eps) {
                break;
            }
            sum += term;
        }
        return sum <= 1.0 ? sum : 1.0;
    }

    /**
     * Computes the inverse distribution function: approximately the smallest {@code x}
     * with {@code F(x) >= u}. {@code F(mode)} is first approximated by summing terms left
     * of the mode; the search then walks right (adding terms) or left (removing terms)
     * from the mode, stopping early when a term falls below {@code EPSILON}.
     *
     * @param gamma parameter gamma, must be {@code > 0}
     * @param p     parameter p, must be in {@code [0, 1]}
     * @param u     probability in {@code [0, 1]}
     * @return the quantile; 0 when {@code p} is 0 or 1 or {@code u <= 0};
     *         {@code Integer.MAX_VALUE} when {@code u == 1}
     * @throws IllegalArgumentException if {@code u} or {@code p} is not in [0, 1] or
     *         {@code gamma <= 0}
     */
    public static int inverseF(double gamma, double p, double u) {
        if (u < 0.0 || u > 1.0) {
            throw new IllegalArgumentException("u is not in [0,1]");
        }
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        if (p >= 1.0 || p <= 0.0 || u <= 0.0) {
            return 0;
        }
        if (u >= 1.0) {
            return Integer.MAX_VALUE;
        }
        final double q = 1.0 - p;
        int mode = 1 + (int)Math.floor((gamma * q - 1.0) / p);
        if (mode < 0) {
            mode = 0;
        }
        final double termMode = NegativeBinomialDist.prob(gamma, p, mode);
        // Approximate F(mode) by adding P(mode - 1), P(mode - 2), ... until negligible.
        double sum = termMode;
        double term = termMode;
        for (int i = mode; i > 0; --i) {
            term *= (double)i / (q * (gamma + (double)i - 1.0));
            if (term < EPSILON) {
                break;
            }
            sum += term;
        }
        int x = mode;
        term = termMode;
        if (sum < u) {
            // Invariant: sum = F(x), term = P(x). Step right while F(x) < u.
            while (sum < u) {
                term *= q * (gamma + (double)x) / (double)(x + 1);
                if (term < EPSILON) {
                    break;
                }
                sum += term;
                ++x;
            }
        } else {
            // Invariant: sum = F(x - 1), term = P(x). Step left while F(x - 1) >= u.
            sum -= term;
            while (x > 0 && sum >= u) {
                // P(x - 1) = P(x) * x / (q (gamma + x - 1)): the ratio uses x before it
                // is decremented, matching the left-sum loop above.
                term *= (double)x / (q * (gamma + (double)x - 1.0));
                --x;
                if (term < EPSILON) {
                    break;
                }
                sum -= term;
            }
        }
        return x;
    }

    /**
     * Estimates {@code p} by maximum likelihood for a known {@code gamma}:
     * {@code p = gamma / (gamma + mean)}, where {@code mean} is the mean of
     * {@code x[0..n-1]}.
     *
     * @return a one-element array {@code {p}}
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static double[] getMLE(int[] x, int n, double gamma) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        double mean = 0.0;
        for (int i = 0; i < n; ++i) {
            mean += (double)x[i];
        }
        mean /= (double)n;
        return new double[]{gamma / (gamma + mean)};
    }

    /**
     * Creates a distribution with the given {@code gamma} and {@code p} estimated by
     * {@link #getMLE(int[], int, double)}.
     */
    public static NegativeBinomialDist getInstanceFromMLE(int[] x, int n, double gamma) {
        double[] parameters = NegativeBinomialDist.getMLE(x, n, gamma);
        return new NegativeBinomialDist(gamma, parameters[0]);
    }

    /**
     * Estimates {@code gamma} by maximum likelihood for a known {@code p}. The likelihood
     * equation is solved with {@code RootFinder.brentDekker} on
     * {@code [gam0 / 10, 10 gam0]}, tolerance {@code 1e-5}, where
     * {@code gam0 = mean * p / (1 - p)} is the moment estimate.
     *
     * @return a one-element array {@code {gamma}}
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static double[] getMLE1(int[] x, int n, double p) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        double mean = 0.0;
        for (int i = 0; i < n; ++i) {
            mean += (double)x[i];
        }
        mean /= (double)n;
        double gam0 = mean * p / (1.0 - p);
        Func1 f = new Func1(p, x, n);
        return new double[]{RootFinder.brentDekker(gam0 / 10.0, 10.0 * gam0, f, 1.0E-5)};
    }

    /**
     * Creates a distribution with the given {@code p} and {@code gamma} estimated by
     * {@link #getMLE1(int[], int, double)}.
     */
    public static NegativeBinomialDist getInstanceFromMLE1(int[] x, int n, double p) {
        double[] param = NegativeBinomialDist.getMLE1(x, n, p);
        return new NegativeBinomialDist(param[0], p);
    }

    /**
     * Estimates {@code (gamma, p)} jointly by maximum likelihood. {@code gamma} is the root
     * (via {@code RootFinder.brentDekker} on {@code [g/10, 10 g]}, tolerance {@code 1e-5},
     * where {@code g = mean^2 / (var - mean)}) of the likelihood equation; then
     * {@code p = gamma / (gamma + mean)}.
     *
     * @return the array {@code {gamma, p}}
     * @throws IllegalArgumentException      if {@code n <= 0}
     * @throws UnsupportedOperationException if the sample mean is not smaller than the
     *                                       sample variance (computed with divisor n)
     */
    public static double[] getMLE(int[] x, int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("n<= 0");
        }
        double sum = 0.0;
        double max = Integer.MIN_VALUE;
        for (int i = 0; i < n; ++i) {
            sum += (double)x[i];
            if ((double)x[i] > max) {
                max = x[i];
            }
        }
        double mean = sum / (double)n;
        double var = 0.0;
        for (int i = 0; i < n; ++i) {
            double d = (double)x[i] - mean;
            var += d * d;
        }
        var /= (double)n;
        if (mean >= var) {
            throw new UnsupportedOperationException("mean >= variance");
        }
        double estimGamma = mean * mean / (var - mean);

        // Fj[j] = number of observations strictly greater than j, for 0 <= j < max.
        // Built from a histogram and a suffix sum (O(n + max) instead of O(n * max)).
        int maxValue = (int)max;
        int[] Fj = new int[maxValue];
        int[] counts = new int[maxValue + 1];
        for (int i = 0; i < n; ++i) {
            if (x[i] >= 0) {
                ++counts[x[i]];
            }
        }
        int greater = 0;
        for (int j = maxValue - 1; j >= 0; --j) {
            greater += counts[j + 1];
            Fj[j] = greater;
        }

        Function f = new Function(n, maxValue, mean, Fj);
        double gammaHat = RootFinder.brentDekker(estimGamma / 10.0, estimGamma * 10.0, f, 1.0E-5);
        double pHat = gammaHat / (gammaHat + mean);
        return new double[]{gammaHat, pHat};
    }

    /**
     * Creates a distribution with both parameters estimated by
     * {@link #getMLE(int[], int)}.
     */
    public static NegativeBinomialDist getInstanceFromMLE(int[] x, int n) {
        double[] parameters = NegativeBinomialDist.getMLE(x, n);
        return new NegativeBinomialDist(parameters[0], parameters[1]);
    }

    /**
     * Returns the mean {@code gamma (1 - p) / p}.
     *
     * @throws IllegalArgumentException if {@code p} is not in [0, 1] or {@code gamma <= 0}
     */
    public static double getMean(double gamma, double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        return gamma * (1.0 - p) / p;
    }

    /**
     * Returns the variance {@code gamma (1 - p) / p^2}.
     *
     * @throws IllegalArgumentException if {@code p} is not in [0, 1] or {@code gamma <= 0}
     */
    public static double getVariance(double gamma, double p) {
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        return gamma * (1.0 - p) / (p * p);
    }

    /**
     * Returns the standard deviation, the square root of
     * {@link #getVariance(double, double)}.
     */
    public static double getStandardDeviation(double gamma, double p) {
        return Math.sqrt(NegativeBinomialDist.getVariance(gamma, p));
    }

    /** Returns the parameter gamma. */
    public double getGamma() {
        return this.gamma;
    }

    /** Returns the parameter p. */
    public double getP() {
        return this.p;
    }

    /**
     * Sets the parameters and, when feasible, precomputes the tables. {@code pdf[]} holds
     * {@code P[X = i]} and {@code cdf[]} holds {@code F(i)} for {@code i <= xmed} and
     * {@code P[X >= i]} for {@code i > xmed}, both for {@code xmin <= i <= xmax}.
     * Terms smaller than {@code EPSILON} (relative to the mode) bound the range.
     * No table is built when {@code p == 0}, the mode is outside {@code [0, MAXN]},
     * or the initial table size would not fit in an {@code int}.
     *
     * @throws IllegalArgumentException if {@code p} is not in [0, 1] or {@code gamma <= 0}
     */
    public void setParams(double gamma, double p) {
        this.supportA = 0;
        if (p < 0.0 || p > 1.0) {
            throw new IllegalArgumentException("p not in [0, 1]");
        }
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        this.gamma = gamma;
        this.p = p;
        final double q = 1.0 - p;
        int mode = 1 + (int)Math.floor((gamma * q - 1.0) / p);
        // Initial table size: mean + 16 standard deviations.
        double tableSize = gamma * q / p + 16.0 * Math.sqrt(gamma * q / (p * p));
        // For p == 0 the instance methods never read the tables, and tableSize is
        // infinite. A size that does not fit in an int (e.g. gamma == 1 with tiny p, whose
        // mode is 0) would overflow the allocation below.
        // NOTE(review): sizes under this bound can still be very large; consider a tighter limit.
        if (p <= 0.0 || mode < 0 || (double)mode > MAXN || !(tableSize < Integer.MAX_VALUE - 1)) {
            this.pdf = null;
            this.cdf = null;
            return;
        }
        int Nmax = (int)tableSize;
        if (Nmax < 32) {
            Nmax = 32;
        }
        double[] P = new double[1 + Nmax];
        double[] F = new double[1 + Nmax];
        // Unnormalized terms with P[mode] = 1; stop when below EPSILON relative to P(mode).
        double epsilon = EPSILON / NegativeBinomialDist.prob(gamma, p, mode);
        P[mode] = 1.0;
        double sum = 1.0;

        int i = mode;
        while (i > 0 && P[i] >= epsilon) {
            P[i - 1] = P[i] * (double)i / (q * (gamma + (double)i - 1.0));
            --i;
            sum += P[i];
        }
        int imin = i;

        i = mode;
        while (P[i] >= epsilon) {
            P[i + 1] = P[i] * q * (gamma + (double)i) / (double)(i + 1);
            ++i;
            sum += P[i];
            if (i == Nmax - 1) {
                // Grow both arrays before the next write can run past the end.
                Nmax *= 2;
                double[] grownP = new double[1 + Nmax];
                System.arraycopy(P, 0, grownP, 0, P.length);
                P = grownP;
                double[] grownF = new double[1 + Nmax];
                System.arraycopy(F, 0, grownF, 0, F.length);
                F = grownF;
            }
        }
        int imax = i;

        for (i = imin; i <= imax; ++i) {
            P[i] /= sum;
        }

        // Left part: F[i] = P[X <= i], up to the first index where it reaches 0.5.
        F[imin] = P[imin];
        i = imin;
        while (i < imax && F[i] < 0.5) {
            ++i;
            F[i] = F[i - 1] + P[i];
        }
        this.xmed = i;

        // Right part: F[i] = P[X >= i] for xmed < i <= imax. A plain loop (not do-while)
        // so that F[xmed] is not overwritten when xmed == imax - 1.
        if (imax > this.xmed) {
            F[imax] = P[imax];
            for (i = imax - 1; i > this.xmed; --i) {
                F[i] = P[i] + F[i + 1];
            }
        }

        this.xmin = imin;
        this.xmax = imax;
        int length = imax + 1 - imin;
        this.pdf = new double[length];
        this.cdf = new double[length];
        System.arraycopy(P, imin, this.pdf, 0, length);
        System.arraycopy(F, imin, this.cdf, 0, length);
    }

    /** Returns the parameters as the array {@code {gamma, p}}. */
    @Override
    public double[] getParams() {
        return new double[]{this.gamma, this.p};
    }

    @Override
    public String toString() {
        return this.getClass().getSimpleName() + " : gamma = " + this.gamma + ", p = " + this.p;
    }

    /**
     * Likelihood equation for gamma when both parameters are estimated:
     * {@code evaluate(s) = sum_j Fj[j] / (s + j) + m ln(s / (s + mean))}.
     * {@code Fj[j]} is the number of observations greater than {@code j}.
     */
    private static class Function
    implements MathFunction {
        protected int m;
        protected int max;
        protected double mean;
        protected int[] Fj;

        public Function(int m, int max, double mean, int[] Fj) {
            this.m = m;
            this.max = max;
            this.mean = mean;
            this.Fj = new int[Fj.length];
            System.arraycopy(Fj, 0, this.Fj, 0, Fj.length);
        }

        @Override
        public double evaluate(double s) {
            if (s <= 0.0) {
                // Large sentinel for the invalid region s <= 0.
                return 1.0E100;
            }
            double sum = 0.0;
            double p = s / (s + this.mean);
            for (int j = 0; j < this.max; ++j) {
                sum += (double)this.Fj[j] / (s + (double)j);
            }
            return sum + (double)this.m * Math.log(p);
        }
    }

    /**
     * Likelihood equation for gamma when p is known:
     * {@code evaluate(gam) = (1/n) sum_j digamma(gam + x[j]) + ln p - digamma(gam)}.
     */
    private static class Func1
    implements MathFunction {
        protected int n;
        protected int[] x;
        protected double p;

        public Func1(double p, int[] x, int n) {
            this.p = p;
            this.n = n;
            this.x = x;
        }

        @Override
        public double evaluate(double gam) {
            if (gam <= 0.0) {
                // Large sentinel for the invalid region gam <= 0.
                return 1.0E100;
            }
            double sum = 0.0;
            for (int j = 0; j < this.n; ++j) {
                sum += Num.digamma(gam + (double)this.x[j]);
            }
            return sum / (double)this.n + Math.log(this.p) - Num.digamma(gam);
        }
    }
}

