/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.empirical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.empirical.kernelfunc.EpanechnikovKF;
import jsat.distributions.empirical.kernelfunc.GaussKF;
import jsat.distributions.empirical.kernelfunc.KernelFunction;
import jsat.distributions.empirical.kernelfunc.UniformKF;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.ProbailityMatch;

public class KernelDensityEstimator
extends ContinuousDistribution {
    private static final long serialVersionUID = 7708020456632603947L;
    private double[] X;
    private double[] weights;
    private double sumOFWeights;
    private double h;
    private double Xmean;
    private double Xvar;
    private double Xskew;
    private KernelFunction k;

    public static double BandwithGuassEstimate(Vec X) {
        if (X.length() == 1) {
            return 1.0;
        }
        if (X.standardDeviation() == 0.0) {
            return 1.06 * Math.pow(X.length(), -0.2);
        }
        return 1.06 * X.standardDeviation() * Math.pow(X.length(), -0.2);
    }

    public static KernelFunction autoKernel(Vec dataPoints) {
        if (dataPoints.length() < 30) {
            return GaussKF.getInstance();
        }
        if (dataPoints.length() < 1000) {
            return EpanechnikovKF.getInstance();
        }
        return UniformKF.getInstance();
    }

    public KernelDensityEstimator(Vec dataPoints) {
        this(dataPoints, KernelDensityEstimator.autoKernel(dataPoints));
    }

    public KernelDensityEstimator(Vec dataPoints, KernelFunction k) {
        this(dataPoints, k, KernelDensityEstimator.BandwithGuassEstimate(dataPoints));
    }

    public KernelDensityEstimator(Vec dataPoints, KernelFunction k, double[] weights) {
        this(dataPoints, k, KernelDensityEstimator.BandwithGuassEstimate(dataPoints), weights);
    }

    public KernelDensityEstimator(Vec dataPoints, KernelFunction k, double h) {
        this.setUpX(dataPoints);
        this.k = k;
        this.h = h;
    }

    public KernelDensityEstimator(Vec dataPoints, KernelFunction k, double h, double[] weights) {
        this.setUpX(dataPoints, weights);
        this.k = k;
        this.h = h;
    }

    private KernelDensityEstimator(double[] X, double h, double Xmean, double Xvar, double Xskew, KernelFunction k, double sumOfWeights, double[] weights) {
        this.X = Arrays.copyOf(X, X.length);
        this.h = h;
        this.Xmean = Xmean;
        this.Xvar = Xvar;
        this.Xskew = Xskew;
        this.k = k;
        this.sumOFWeights = sumOfWeights;
        this.weights = Arrays.copyOf(weights, weights.length);
    }

    private void setUpX(Vec S) {
        this.Xmean = S.mean();
        this.Xvar = S.variance();
        this.Xskew = S.skewness();
        this.X = S.arrayCopy();
        Arrays.sort(this.X);
        this.sumOFWeights = this.X.length;
        this.weights = new double[0];
    }

    private void setUpX(Vec S, double[] weights) {
        int i;
        if (S.length() != weights.length) {
            throw new RuntimeException("Weights and variables do not have the same length");
        }
        OnLineStatistics stats = new OnLineStatistics();
        this.X = new double[S.length()];
        this.weights = Arrays.copyOf(weights, S.length());
        ArrayList<ProbailityMatch<Double>> sorter = new ArrayList<ProbailityMatch<Double>>(S.length());
        for (i = 0; i < S.length(); ++i) {
            sorter.add(new ProbailityMatch<Double>(S.get(i), weights[i]));
        }
        Collections.sort(sorter);
        for (i = 0; i < sorter.size(); ++i) {
            this.X[i] = ((ProbailityMatch)sorter.get(i)).getProbability();
            this.weights[i] = (Double)((ProbailityMatch)sorter.get(i)).getMatch();
            stats.add(this.X[i], this.weights[i]);
        }
        for (i = 1; i < this.weights.length; ++i) {
            int n = i;
            this.weights[n] = this.weights[n] + this.weights[i - 1];
        }
        this.sumOFWeights = this.weights[this.weights.length - 1];
        this.Xmean = stats.getMean();
        this.Xvar = stats.getVarance();
        this.Xskew = stats.getSkewness();
    }

    private double getWeight(int i) {
        if (this.weights.length == 0) {
            return 1.0;
        }
        if (i == 0) {
            return this.weights[i];
        }
        return this.weights[i] - this.weights[i - 1];
    }

    @Override
    public double pdf(double x) {
        return this.pdf(x, -1);
    }

    private double pdf(double x, int j) {
        int from = Arrays.binarySearch(this.X, x - this.h * this.k.cutOff());
        int to = Arrays.binarySearch(this.X, x + this.h * this.k.cutOff());
        from = from < 0 ? -from - 1 : from;
        int n = to = to < 0 ? -to - 1 : to;
        if (this.weights.length == 0 && this.k instanceof UniformKF) {
            return (double)(to - from) * 0.5 / (this.sumOFWeights * this.h);
        }
        double sum = 0.0;
        for (int i = Math.max(0, from); i < Math.min(this.X.length, to + 1); ++i) {
            if (i == j) continue;
            sum += this.k.k((x - this.X[i]) / this.h) * this.getWeight(i);
        }
        return sum / (this.sumOFWeights * this.h);
    }

    @Override
    public double cdf(double x) {
        int from = Arrays.binarySearch(this.X, x - this.h * this.k.cutOff());
        int to = Arrays.binarySearch(this.X, x + this.h * this.k.cutOff());
        from = from < 0 ? -from - 1 : from;
        to = to < 0 ? -to - 1 : to;
        double sum = 0.0;
        for (int i = Math.max(0, from); i < Math.min(this.X.length, to + 1); ++i) {
            sum += this.k.intK((x - this.X[i]) / this.h) * this.getWeight(i);
        }
        sum = this.weights.length == 0 ? (sum += (double)Math.max(0, from)) : (sum += this.weights[from]);
        return sum / (double)this.X.length;
    }

    @Override
    public double invCdf(double p) {
        double kd0;
        int index;
        if (this.weights.length == 0) {
            double r = p * (double)this.X.length;
            index = (int)r;
            double pd0 = r - (double)index;
            double pd1 = 1.0 - pd0;
            kd0 = this.k.intK(pd1);
        } else {
            double XEstimate = p * this.sumOFWeights;
            index = Arrays.binarySearch(this.weights, XEstimate);
            index = index < 0 ? -index - 1 : index;
            kd0 = this.X[index] != 0.0 ? 1.0 : 1.0;
        }
        if (index == this.X.length - 1) {
            return this.X[index] * kd0;
        }
        double x = this.X[index] * kd0 + this.X[index + 1] * (1.0 - kd0);
        return x;
    }

    @Override
    public double min() {
        return this.X[0] - this.h;
    }

    @Override
    public double max() {
        return this.X[this.X.length - 1] + this.h;
    }

    @Override
    public String getDistributionName() {
        return "Kernel Density Estimate";
    }

    @Override
    public String[] getVariables() {
        return new String[]{"h"};
    }

    @Override
    public double[] getCurrentVariableValues() {
        return new double[]{this.h};
    }

    public void setBandwith(double val) {
        if (val <= 0.0 || Double.isInfinite(val)) {
            throw new ArithmeticException("Bandwith parameter h must be greater than zero, not 0");
        }
        this.h = val;
    }

    public double getBandwith() {
        return this.h;
    }

    @Override
    public void setVariable(String var, double value) {
        if (var.equals("h")) {
            this.setBandwith(value);
        }
    }

    @Override
    public KernelDensityEstimator clone() {
        return new KernelDensityEstimator(this.X, this.h, this.Xmean, this.Xvar, this.Xskew, this.k, this.sumOFWeights, this.weights);
    }

    @Override
    public void setUsingData(Vec data) {
        this.setUpX(data);
        this.h = KernelDensityEstimator.BandwithGuassEstimate(data);
    }

    @Override
    public double mean() {
        return this.Xmean;
    }

    @Override
    public double mode() {
        double maxP = 0.0;
        double maxV = Double.NaN;
        for (int i = 0; i < this.X.length; ++i) {
            double d;
            double pTmp = this.pdf(this.X[i]);
            if (!(d > maxP)) continue;
            maxP = pTmp;
            maxV = this.X[i];
        }
        return maxV;
    }

    @Override
    public double variance() {
        return this.Xvar + this.h * this.h * this.k.k2();
    }

    @Override
    public double skewness() {
        return this.Xskew;
    }

    public int hashCode() {
        int prime = 31;
        int result = 1;
        result = 31 * result + Arrays.hashCode(this.X);
        long temp = Double.doubleToLongBits(this.Xmean);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        temp = Double.doubleToLongBits(this.Xskew);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        temp = Double.doubleToLongBits(this.Xvar);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        temp = Double.doubleToLongBits(this.h);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        result = 31 * result + (this.k == null ? 0 : this.k.hashCode());
        temp = Double.doubleToLongBits(this.sumOFWeights);
        result = 31 * result + (int)(temp ^ temp >>> 32);
        result = 31 * result + Arrays.hashCode(this.weights);
        return result;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null) {
            return false;
        }
        if (!(obj instanceof KernelDensityEstimator)) {
            return false;
        }
        KernelDensityEstimator other = (KernelDensityEstimator)obj;
        if (Double.doubleToLongBits(this.Xmean) != Double.doubleToLongBits(other.Xmean)) {
            return false;
        }
        if (Double.doubleToLongBits(this.Xskew) != Double.doubleToLongBits(other.Xskew)) {
            return false;
        }
        if (Double.doubleToLongBits(this.Xvar) != Double.doubleToLongBits(other.Xvar)) {
            return false;
        }
        if (Double.doubleToLongBits(this.h) != Double.doubleToLongBits(other.h)) {
            return false;
        }
        if (Double.doubleToLongBits(this.sumOFWeights) != Double.doubleToLongBits(other.sumOFWeights)) {
            return false;
        }
        if (this.k == null ? other.k != null : this.k.getClass() != other.k.getClass()) {
            return false;
        }
        if (!Arrays.equals(this.X, other.X)) {
            return false;
        }
        return Arrays.equals(this.weights, other.weights);
    }
}

