/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions;

import jsat.distributions.ContinuousDistribution;
import jsat.linear.Vec;

/**
 * The Normal (Gaussian) distribution, parameterised by its mean (&mu;) and
 * standard deviation (&sigma;).
 *
 * <p>Besides the instance API inherited from {@link ContinuousDistribution},
 * static variants of the CDF, inverse CDF, PDF, log-PDF and z-transform are
 * provided that take the mean and standard deviation explicitly. The static
 * variants do not validate {@code mu} or {@code sigma}.
 */
public class Normal
extends ContinuousDistribution {
    private static final long serialVersionUID = -5298346576152986165L;

    /** Message used when a probability outside [0,1] is handed to the inverse CDF. */
    private static final String BAD_PROBABILITY_MSG = "Inverse of a probability requires a probablity in the range [0,1], not ";

    /** Numerator coefficients of the central-region rational approximation used by the inverse CDF. */
    private static final double[] INV_A = new double[]{-39.69683028665376, 220.9460984245205, -275.9285104469687, 138.357751867269, -30.66479806614716, 2.506628277459239};
    /** Denominator coefficients of the central-region rational approximation used by the inverse CDF. */
    private static final double[] INV_B = new double[]{-54.47609879822406, 161.5858368580409, -155.6989798598866, 66.80131188771972, -13.28068155288572};
    /** Numerator coefficients of the tail-region rational approximation used by the inverse CDF. */
    private static final double[] INV_C = new double[]{-0.007784894002430293, -0.3223964580411365, -2.400758277161838, -2.549732539343734, 4.374664141464968, 2.938163982698783};
    /** Denominator coefficients of the tail-region rational approximation used by the inverse CDF. */
    private static final double[] INV_D = new double[]{0.007784695709041462, 0.3224671290700398, 2.445134137142996, 3.754408661907416};
    /** Probabilities below this value are handled by the lower-tail approximation. */
    private static final double INV_P_LOW = 0.02425;
    /** Probabilities above this value are handled by the upper-tail approximation. */
    private static final double INV_P_HIGH = 1.0 - INV_P_LOW;

    private double mean;
    private double stndDev;

    /**
     * Creates a standard normal distribution (mean 0, standard deviation 1).
     */
    public Normal() {
        this(0.0, 1.0);
    }

    /**
     * Creates a normal distribution with the given parameters.
     *
     * @param mean the mean of the distribution; must be finite
     * @param stndDev the standard deviation; must be finite and greater than zero
     * @throws RuntimeException if {@code stndDev <= 0}
     * @throws ArithmeticException if either value is NaN or infinite
     */
    public Normal(double mean, double stndDev) {
        if (stndDev <= 0.0) {
            throw new RuntimeException("Standerd deviation of the normal distribution needs to be greater than zero");
        }
        this.setMean(mean);
        this.setStndDev(stndDev);
    }

    /**
     * Sets the mean of the distribution.
     *
     * @param mean the new mean
     * @throws ArithmeticException if {@code mean} is NaN or infinite
     */
    public void setMean(double mean) {
        Normal.checkMean(mean);
        this.mean = mean;
    }

    /**
     * Sets the standard deviation of the distribution.
     *
     * @param stndDev the new standard deviation
     * @throws ArithmeticException if {@code stndDev} is NaN, infinite, or not greater than zero
     */
    public void setStndDev(double stndDev) {
        Normal.checkStndDev(stndDev);
        this.stndDev = stndDev;
    }

    /** Throws if the value is not usable as a mean (NaN or infinite). */
    private static void checkMean(double mean) {
        if (Double.isInfinite(mean) || Double.isNaN(mean)) {
            throw new ArithmeticException("Mean can not be infinite of NaN");
        }
    }

    /** Throws if the value is not usable as a standard deviation (NaN, infinite, or not positive). */
    private static void checkStndDev(double stndDev) {
        if (Double.isInfinite(stndDev) || Double.isNaN(stndDev)) {
            throw new ArithmeticException("Standard devation can not be infinite of NaN");
        }
        if (stndDev <= 0.0) {
            throw new ArithmeticException("The standard devation can not be <= 0");
        }
    }

    /**
     * Computes the cumulative distribution function of a normal distribution.
     *
     * @param x the point to evaluate at; must be finite
     * @param mu the mean of the distribution
     * @param sigma the standard deviation of the distribution
     * @return the approximate probability that a sample is {@code <= x}; NaN if
     *         the z-transform of {@code x} is NaN (for example when {@code mu}
     *         is NaN)
     * @throws ArithmeticException if {@code x} is NaN or infinite
     */
    public static double cdf(double x, double mu, double sigma) {
        if (Double.isNaN(x) || Double.isInfinite(x)) {
            throw new ArithmeticException("X is not a real number");
        }
        return Normal.cdfApproxMarsaglia2004(Normal.zTransform(x, mu, sigma));
    }

    @Override
    public double cdf(double x) {
        return Normal.cdf(x, this.mean, this.stndDev);
    }

    /**
     * Computes the inverse cumulative distribution function (quantile) of a
     * normal distribution.
     *
     * <p>The standard-normal quantile is computed first and then scaled as
     * {@code z * sigma + mu}. For {@code x == 0} the standard-normal quantile
     * is negative infinity and for {@code x == 1} it is positive infinity.
     *
     * @param x the probability, in the range [0,1]
     * @param mu the mean of the distribution
     * @param sigma the standard deviation of the distribution
     * @return the value whose CDF is approximately {@code x}
     * @throws ArithmeticException if {@code x} is NaN
     * @throws RuntimeException if {@code x} is outside [0,1]
     */
    public static double invcdf(double x, double mu, double sigma) {
        if (Double.isNaN(x)) {
            // NaN passes neither range comparison below; reject it explicitly.
            throw new ArithmeticException(BAD_PROBABILITY_MSG + x);
        }
        if (x < 0.0 || x > 1.0) {
            throw new RuntimeException(BAD_PROBABILITY_MSG + x);
        }
        return Normal.standardQuantile(x) * sigma + mu;
    }

    /**
     * Quantile of the standard normal distribution for a probability already
     * known to lie in [0,1]: a piecewise rational approximation followed by a
     * single refinement step against the CDF approximation.
     */
    private static double standardQuantile(double p) {
        // The rational approximations are only defined on the open interval;
        // the end points map to the infinite ends of the support.
        if (p == 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        if (p == 1.0) {
            return Double.POSITIVE_INFINITY;
        }
        double z;
        if (p < INV_P_LOW) {
            z = Normal.tailQuantile(Math.sqrt(-2.0 * Math.log(p)));
        } else if (p <= INV_P_HIGH) {
            double q = p - 0.5;
            double r = q * q;
            z = (((((INV_A[0] * r + INV_A[1]) * r + INV_A[2]) * r + INV_A[3]) * r + INV_A[4]) * r + INV_A[5]) * q / (((((INV_B[0] * r + INV_B[1]) * r + INV_B[2]) * r + INV_B[3]) * r + INV_B[4]) * r + 1.0);
        } else {
            // Upper tail mirrors the lower tail.
            z = -Normal.tailQuantile(Math.sqrt(-2.0 * Math.log(1.0 - p)));
        }
        double cdfAtZ = Normal.cdf(z, 0.0, 1.0);
        // Once the CDF approximation has saturated at exactly 0 or 1 it says
        // nothing about the residual, and exp(z*z/2) may overflow; refining in
        // that case would only push the estimate away from the approximation.
        if (cdfAtZ > 0.0 && cdfAtZ < 1.0) {
            double e = cdfAtZ - p;
            double u = e * Math.sqrt(Math.PI * 2) * Math.exp(z * z / 2.0);
            z -= u / (1.0 + z * u / 2.0);
        }
        return z;
    }

    /** Evaluates the tail-region rational approximation at {@code q}; the result is the lower-tail quantile. */
    private static double tailQuantile(double q) {
        return (((((INV_C[0] * q + INV_C[1]) * q + INV_C[2]) * q + INV_C[3]) * q + INV_C[4]) * q + INV_C[5]) / ((((INV_D[0] * q + INV_D[1]) * q + INV_D[2]) * q + INV_D[3]) * q + 1.0);
    }

    @Override
    public double invCdf(double d) {
        return Normal.invcdf(d, this.mean, this.stndDev);
    }

    /**
     * Computes the probability density function of a normal distribution.
     *
     * @param x the point to evaluate at
     * @param mu the mean of the distribution
     * @param sigma the standard deviation of the distribution
     * @return the density at {@code x}
     */
    public static double pdf(double x, double mu, double sigma) {
        return 1.0 / Math.sqrt(Math.PI * 2 * sigma * sigma) * Math.exp(-Math.pow(x - mu, 2.0) / (2.0 * sigma * sigma));
    }

    @Override
    public double pdf(double d) {
        return Normal.pdf(d, this.mean, this.stndDev);
    }

    /**
     * Computes the natural logarithm of the probability density function of a
     * normal distribution.
     *
     * @param x the point to evaluate at
     * @param mu the mean of the distribution
     * @param sigma the standard deviation of the distribution
     * @return the log of the density at {@code x}
     */
    public static double logPdf(double x, double mu, double sigma) {
        return -0.5 * Math.log(Math.PI * 2) - Math.log(sigma) + -Math.pow(x - mu, 2.0) / (2.0 * sigma * sigma);
    }

    @Override
    public double logPdf(double x) {
        return Normal.logPdf(x, this.mean, this.stndDev);
    }

    /**
     * Computes the reciprocal of this distribution's density at the given point.
     *
     * @param d the point to evaluate at
     * @return {@code 1 / pdf(d)}, evaluated directly
     */
    public double invPdf(double d) {
        return Math.exp(Math.pow(this.mean - d, 2.0) / (2.0 * Math.pow(this.stndDev, 2.0))) * Math.sqrt(Math.PI * 2) * this.stndDev;
    }

    /**
     * Standardises a value: {@code (x - mu) / sigma}.
     *
     * @param x the value to transform
     * @param mu the mean
     * @param sigma the standard deviation
     * @return the z-score of {@code x}
     */
    public static double zTransform(double x, double mu, double sigma) {
        return (x - mu) / sigma;
    }

    /**
     * Standardises a value using this distribution's mean and standard deviation.
     *
     * @param x the value to transform
     * @return the z-score of {@code x}
     */
    public double zTransform(double x) {
        return Normal.zTransform(x, this.mean, this.stndDev);
    }

    /**
     * Approximates the standard normal CDF by summing a series until adding a
     * further term no longer changes the partial sum.
     *
     * @param x the z-score to evaluate at
     * @return the approximate CDF value; exactly 1 for {@code x >= 8.22},
     *         exactly 0 for {@code x <= -8.22}, and NaN for a NaN input
     */
    private static double cdfApproxMarsaglia2004(double x) {
        if (Double.isNaN(x)) {
            // NaN never compares equal to itself, so the loop below would not terminate.
            return Double.NaN;
        }
        if (x >= 8.22) {
            return 1.0;
        }
        if (x <= -8.22) {
            return 0.0;
        }
        double s = x;
        double t = 0.0;
        double b = x;
        double q = x * x;
        double i = 1.0;
        while (s != t) {
            t = s;
            i += 2.0;
            b *= q / i;
            s = t + b;
        }
        // 0.9189385332046728 is ln(sqrt(2*pi)).
        return 0.5 + s * Math.exp(-0.5 * q - 0.9189385332046728);
    }

    @Override
    public String getDescriptiveName() {
        return "Normal(\u03bc=" + this.mean + ", \u03c3=" + this.stndDev + ")";
    }

    @Override
    public double min() {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double max() {
        return Double.POSITIVE_INFINITY;
    }

    @Override
    public String getDistributionName() {
        return "Normal";
    }

    @Override
    public String[] getVariables() {
        return new String[]{"\u03bc", "\u03c3"};
    }

    /**
     * Sets the mean (variable name {@code "\u03bc"}) or the standard deviation
     * (variable name {@code "\u03c3"}); any other name is ignored. Both go
     * through the validating setters.
     *
     * @throws ArithmeticException if the value is not valid for the named variable
     */
    @Override
    public void setVariable(String var, double value) {
        if ("\u03bc".equals(var)) {
            this.setMean(value);
        } else if ("\u03c3".equals(var)) {
            this.setStndDev(value);
        }
    }

    @Override
    public ContinuousDistribution clone() {
        return new Normal(this.mean, this.stndDev);
    }

    /**
     * Sets the mean and standard deviation from the mean and standard
     * deviation reported by the given data. Both values are validated before
     * either is stored, so a rejected value leaves this object unchanged.
     *
     * @throws ArithmeticException if the data's mean is NaN or infinite, or its
     *         standard deviation is NaN, infinite, or not greater than zero
     */
    @Override
    public void setUsingData(Vec data) {
        double newMean = data.mean();
        double newStndDev = data.standardDeviation();
        Normal.checkMean(newMean);
        Normal.checkStndDev(newStndDev);
        this.mean = newMean;
        this.setStndDev(newStndDev);
    }

    @Override
    public double[] getCurrentVariableValues() {
        return new double[]{this.mean, this.stndDev};
    }

    @Override
    public double mean() {
        return this.mean;
    }

    @Override
    public double median() {
        return this.mean;
    }

    @Override
    public double mode() {
        return this.mean;
    }

    @Override
    public double variance() {
        return this.stndDev * this.stndDev;
    }

    @Override
    public double standardDeviation() {
        return this.stndDev;
    }

    @Override
    public double skewness() {
        return 0.0;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        long temp = Double.doubleToLongBits(this.mean);
        result = prime * result + (int)(temp ^ temp >>> 32);
        temp = Double.doubleToLongBits(this.stndDev);
        result = prime * result + (int)(temp ^ temp >>> 32);
        return result;
    }

    /**
     * Two instances are equal when they are of the same class and their mean
     * and standard deviation have identical bit patterns.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (this.getClass() != obj.getClass()) {
            return false;
        }
        Normal other = (Normal)obj;
        if (Double.doubleToLongBits(this.mean) != Double.doubleToLongBits(other.mean)) {
            return false;
        }
        return Double.doubleToLongBits(this.stndDev) == Double.doubleToLongBits(other.stndDev);
    }
}

