/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions;

import java.util.Random;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.Gamma;
import jsat.linear.Vec;
import jsat.math.SpecialMath;

/**
 * The Beta distribution on the open interval (0, 1), parameterized by two
 * strictly positive shape parameters {@code alpha} and {@code beta}.
 */
public class Beta
extends ContinuousDistribution {
    private static final long serialVersionUID = 8001402067928143972L;
    /** First shape parameter; expected to be &gt; 0. */
    double alpha;
    /** Second shape parameter; expected to be &gt; 0. */
    double beta;

    /**
     * Creates a new Beta distribution.
     *
     * @param alpha the first shape parameter, must be &gt; 0
     * @param beta the second shape parameter, must be &gt; 0
     * @throws ArithmeticException if either parameter is not strictly
     *         positive (this includes NaN)
     */
    public Beta(double alpha, double beta) {
        // "!(x > 0)" rather than "x <= 0" so that NaN is rejected as well,
        // matching the validation performed by setVariable.
        if (!(alpha > 0.0)) {
            throw new ArithmeticException("Alpha must be > 0, not " + alpha);
        }
        if (!(beta > 0.0)) {
            throw new ArithmeticException("Beta must be > 0, not " + beta);
        }
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Computes the natural log of the density at {@code x}.
     *
     * @param x the point to evaluate
     * @return the log density, or {@code -Double.MAX_VALUE} when {@code x}
     *         is outside the open interval (0, 1)
     */
    @Override
    public double logPdf(double x) {
        if (x <= 0.0 || x >= 1.0) {
            return -Double.MAX_VALUE;
        }
        return (this.alpha - 1.0) * Math.log(x)
                + (this.beta - 1.0) * Math.log(1.0 - x)
                - SpecialMath.lnBeta(this.alpha, this.beta);
    }

    /**
     * Computes the density at {@code x}.
     *
     * @param x the point to evaluate
     * @return the density, or 0 when {@code x} is outside (0, 1)
     */
    @Override
    public double pdf(double x) {
        if (x <= 0.0 || x >= 1.0) {
            return 0.0;
        }
        return Math.exp(this.logPdf(x));
    }

    /**
     * Computes the cumulative distribution function at {@code x}.
     *
     * @param x the point to evaluate
     * @return 0 for {@code x <= 0}, 1 for {@code x >= 1}, otherwise the
     *         value of {@code SpecialMath.betaIncReg(x, alpha, beta)}
     */
    @Override
    public double cdf(double x) {
        if (x <= 0.0) {
            return 0.0;
        }
        if (x >= 1.0) {
            return 1.0;
        }
        return SpecialMath.betaIncReg(x, this.alpha, this.beta);
    }

    /**
     * Computes the inverse of the cumulative distribution function.
     *
     * @param p the probability, must be in the range [0, 1]
     * @return the value of {@code SpecialMath.invBetaIncReg(p, alpha, beta)}
     * @throws ArithmeticException if {@code p} is outside [0, 1] or is NaN
     */
    @Override
    public double invCdf(double p) {
        // Written as a negated range test so that NaN is rejected too.
        if (!(p >= 0.0 && p <= 1.0)) {
            throw new ArithmeticException("p must be in the range [0,1], not " + p);
        }
        return SpecialMath.invBetaIncReg(p, this.alpha, this.beta);
    }

    /** @return the lower bound of the support, 0 */
    @Override
    public double min() {
        return 0.0;
    }

    /** @return the upper bound of the support, 1 */
    @Override
    public double max() {
        return 1.0;
    }

    /** @return the name of this distribution, {@code "Beta"} */
    @Override
    public String getDistributionName() {
        return "Beta";
    }

    /** @return the names of the parameters accepted by {@link #setVariable} */
    @Override
    public String[] getVariables() {
        return new String[]{"alpha", "beta"};
    }

    /**
     * @return the current parameter values, in the same order as the names
     *         returned by {@link #getVariables()}
     */
    @Override
    public double[] getCurrentVariableValues() {
        return new double[]{this.alpha, this.beta};
    }

    /**
     * Sets one of the shape parameters by name. Names other than
     * {@code "alpha"} and {@code "beta"} are silently ignored.
     *
     * @param var the parameter name
     * @param value the new value, must be &gt; 0
     * @throws ArithmeticException if {@code var} names a known parameter and
     *         {@code value} is not strictly positive (this includes NaN).
     *         This is the same exception type the constructor uses and is a
     *         {@link RuntimeException}.
     */
    @Override
    public void setVariable(String var, double value) {
        if (var.equals("alpha")) {
            if (!(value > 0.0)) {
                throw new ArithmeticException("Alpha must be > 0, not " + value);
            }
            this.alpha = value;
        } else if (var.equals("beta")) {
            if (!(value > 0.0)) {
                throw new ArithmeticException("Beta must be > 0, not " + value);
            }
            this.beta = value;
        }
    }

    /**
     * Creates a copy of this distribution with the same parameter values.
     *
     * @return a new, independent {@code Beta} instance
     */
    @Override
    public ContinuousDistribution clone() {
        // Copy the fields directly instead of passing them back through the
        // validating constructor: setUsingData does not validate its
        // estimates, and cloning must not throw because of the current state.
        Beta copy = new Beta(1.0, 1.0);
        copy.alpha = this.alpha;
        copy.beta = this.beta;
        return copy;
    }

    /**
     * Estimates both parameters from the mean and variance of {@code data}
     * (method of moments).
     * <p>
     * NOTE: the estimates are not validated. If the variance is zero, or the
     * mean/variance are not consistent with a Beta distribution, the
     * resulting parameters may be non-positive, infinite or NaN.
     *
     * @param data the observations to fit to
     */
    @Override
    public void setUsingData(Vec data) {
        double mean = data.mean();
        double var = data.variance();
        // alpha = (mean^2 - mean^3 - mean * var) / var
        this.alpha = (mean * mean - mean * mean * mean - mean * var) / var;
        // beta = alpha * (1 - mean) / mean
        this.beta = (this.alpha - this.alpha * mean) / mean;
    }

    /** @return the mean, {@code alpha / (alpha + beta)} */
    @Override
    public double mean() {
        return this.alpha / (this.alpha + this.beta);
    }

    /** @return the median, computed as the inverse CDF at 0.5 */
    @Override
    public double median() {
        return SpecialMath.invBetaIncReg(0.5, this.alpha, this.beta);
    }

    /**
     * @return {@code (alpha - 1) / (alpha + beta - 2)} when both parameters
     *         are greater than 1, otherwise {@link Double#NaN}
     */
    @Override
    public double mode() {
        if (this.alpha > 1.0 && this.beta > 1.0) {
            return (this.alpha - 1.0) / (this.alpha + this.beta - 2.0);
        }
        return Double.NaN;
    }

    /**
     * @return the variance,
     *         {@code alpha * beta / ((alpha + beta)^2 * (alpha + beta + 1))}
     */
    @Override
    public double variance() {
        double sum = this.alpha + this.beta;
        return this.alpha * this.beta / (Math.pow(sum, 2.0) * (sum + 1.0));
    }

    /**
     * @return the skewness, {@code 2 (beta - alpha) sqrt(alpha + beta + 1)
     *         / ((alpha + beta + 2) sqrt(alpha * beta))}
     */
    @Override
    public double skewness() {
        double sum = this.alpha + this.beta;
        return 2.0 * (this.beta - this.alpha) * Math.sqrt(sum + 1.0)
                / ((sum + 2.0) * Math.sqrt(this.alpha * this.beta));
    }

    /**
     * Draws samples by taking one sample array from {@code Gamma(alpha, 1)}
     * and one from {@code Gamma(beta, 1)} and returning, element-wise,
     * {@code a / (a + b)}.
     *
     * @param numSamples the number of samples requested from each Gamma
     * @param rand the source of randomness
     * @return the array of samples
     */
    @Override
    public double[] sample(int numSamples, Random rand) {
        double[] a = new Gamma(this.alpha, 1.0).sample(numSamples, rand);
        double[] b = new Gamma(this.beta, 1.0).sample(numSamples, rand);
        // Reuse the first array to hold the result.
        for (int i = 0; i < a.length; ++i) {
            a[i] = a[i] / (a[i] + b[i]);
        }
        return a;
    }

    /**
     * Hash code derived from the bit patterns of both parameters, consistent
     * with {@link #equals(Object)}.
     */
    @Override
    public int hashCode() {
        int result = 1;
        long bits = Double.doubleToLongBits(this.alpha);
        result = 31 * result + (int)(bits ^ (bits >>> 32));
        bits = Double.doubleToLongBits(this.beta);
        result = 31 * result + (int)(bits ^ (bits >>> 32));
        return result;
    }

    /**
     * Two instances are equal when they are of exactly the same class and
     * both parameters have identical bit patterns.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        Beta other = (Beta)obj;
        return Double.doubleToLongBits(this.alpha) == Double.doubleToLongBits(other.alpha)
                && Double.doubleToLongBits(this.beta) == Double.doubleToLongBits(other.beta);
    }
}

