/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.distributions.Gamma;
import jsat.distributions.multivariate.MultivariateDistributionSkeleton;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.SpecialMath;
import jsat.math.optimization.NelderMead;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Dirichlet distribution parameterised by a vector of concentration parameters
 * {@code alphas}, each of which must be a finite, strictly positive value.
 * <p>
 * The support is the probability simplex: {@link #logPdf(Vec)} only assigns
 * density to vectors whose components are all positive and sum to 1 (within
 * {@link #SIMPLEX_TOLERANCE}).
 */
public class Dirichlet
extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = 6229508050763067569L;
    /** Maximum allowed deviation of {@code sum(x)} from 1 for {@code x} to be considered on the simplex. */
    private static final double SIMPLEX_TOLERANCE = 1.0E-14;
    private Vec alphas;

    /**
     * Creates a Dirichlet distribution with the given concentration parameters.
     *
     * @param alphas the concentration parameters; copied, not retained
     * @throws ArithmeticException if any parameter is non-positive, NaN, or infinite
     */
    public Dirichlet(Vec alphas) {
        // Validate via a private helper rather than the overridable setAlphas.
        this.alphas = validatedCopy(alphas);
    }

    /**
     * Replaces the concentration parameters.
     *
     * @param alphas the new concentration parameters; copied, not retained
     * @throws ArithmeticException if any parameter is non-positive, NaN, or infinite
     */
    public void setAlphas(Vec alphas) throws ArithmeticException {
        this.alphas = validatedCopy(alphas);
    }

    /**
     * Returns the concentration parameters.
     * <p>
     * NOTE: this is the internal vector, not a copy; mutating it changes this
     * distribution without validation.
     *
     * @return the concentration parameter vector
     */
    public Vec getAlphas() {
        return this.alphas;
    }

    @Override
    public Dirichlet clone() {
        return new Dirichlet(this.alphas);
    }

    /**
     * Computes the log of the density at {@code x}.
     *
     * @param x the point to evaluate; must have the same dimension as this distribution
     * @return the log density, or {@code -Double.MAX_VALUE} when {@code x} is outside
     *         the simplex (a non-positive or NaN component, or components not summing to 1)
     * @throws ArithmeticException if {@code x} has the wrong dimension
     */
    @Override
    public double logPdf(Vec x) {
        final int d = this.alphas.length();
        if (x.length() != d) {
            throw new ArithmeticException(d + " variable distribution can not answer a " + x.length() + " dimension variable");
        }
        double logVal = 0.0;
        double sum = 0.0;
        for (int i = 0; i < d; ++i) {
            double xi = x.get(i);
            // !(xi > 0) also rejects NaN, which a plain "xi <= 0" check lets through.
            if (!(xi > 0.0)) {
                return -Double.MAX_VALUE;
            }
            sum += xi;
            logVal += Math.log(xi) * (this.alphas.get(i) - 1.0);
        }
        if (Math.abs(sum - 1.0) > SIMPLEX_TOLERANCE) {
            return -Double.MAX_VALUE;
        }
        // log of the normalising constant: sum_i lnGamma(alpha_i) - lnGamma(sum_i alpha_i)
        double logNormalizer = 0.0;
        for (int i = 0; i < d; ++i) {
            logNormalizer += SpecialMath.lnGamma(this.alphas.get(i));
        }
        logNormalizer -= SpecialMath.lnGamma(this.alphas.sum());
        return logVal - logNormalizer;
    }

    @Override
    public double pdf(Vec x) {
        return Math.exp(this.logPdf(x));
    }

    /**
     * Fits the concentration parameters to {@code dataSet} by minimising the negative
     * log-likelihood with Nelder-Mead.
     *
     * @param dataSet  the observed points; assumed to lie on the simplex (TODO confirm callers)
     * @param parallel whether the likelihood and optimiser may run in parallel
     * @return {@code true} if new parameters were fitted; {@code false} if {@code dataSet}
     *         is empty or the optimiser produced invalid parameters, in which case the
     *         current parameters are left unchanged
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, boolean parallel) {
        if (dataSet.isEmpty()) {
            return false;
        }
        Function negLogLike = (x, p) -> {
            // Nelder-Mead is unconstrained and may probe points outside the parameter
            // space, where lnGamma/the density are undefined (possibly NaN). Give such
            // points the worst finite score so they never win a comparison.
            if (indexOfInvalid(x) >= 0) {
                return Double.MAX_VALUE;
            }
            double constantTerm = SpecialMath.lnGamma(x.sum());
            for (int i = 0; i < x.length(); ++i) {
                constantTerm -= SpecialMath.lnGamma(x.get(i));
            }
            double sum = ParallelUtils.run(p, dataSet.size(), (start, end) -> {
                double localSum = 0.0;
                for (int i = start; i < end; ++i) {
                    Vec s = dataSet.get(i);
                    for (int j = 0; j < x.length(); ++j) {
                        localSum += Math.log(s.get(j)) * (x.get(j) - 1.0);
                    }
                }
                return localSum;
            }, (a, b) -> a + b);
            return -(sum + constantTerm * (double)dataSet.size());
        };
        NelderMead optimize = new NelderMead();
        DenseVector guess = new DenseVector(dataSet.get(0).length());
        List<Vec> guesses = new ArrayList<Vec>();
        // Initial simplex vertices: all-ones, all-0.1 and all-10 concentration vectors.
        guesses.add(guess.add(1.0));
        guesses.add(guess.add(0.1));
        guesses.add(guess.add(10.0));
        Vec fitted = optimize.optimize(1.0E-10, 100, negLogLike, guesses, parallel);
        if (indexOfInvalid(fitted) >= 0) {
            return false;
        }
        this.alphas = fitted;
        return true;
    }

    /**
     * Draws {@code count} samples by drawing independent {@code Gamma(alpha_j, 1.0)}
     * variates per dimension and normalising each resulting vector to sum to 1.
     *
     * @param count number of samples to draw
     * @param rand  source of randomness
     * @return a list of {@code count} sample vectors
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        final int d = this.alphas.length();
        double[][] gammaSamples = new double[d][];
        for (int j = 0; j < d; ++j) {
            Gamma gamma = new Gamma(this.alphas.get(j), 1.0);
            gammaSamples[j] = gamma.sample(count, rand);
        }
        List<Vec> samples = new ArrayList<Vec>(count);
        for (int i = 0; i < count; ++i) {
            Vec sample = new DenseVector(d);
            for (int j = 0; j < d; ++j) {
                sample.set(j, gammaSamples[j][i]);
            }
            sample.mutableDivide(sample.sum());
            samples.add(sample);
        }
        return samples;
    }

    /**
     * Returns the index of the first entry of {@code v} that is not a finite,
     * strictly positive number, or -1 if every entry is valid.
     */
    private static int indexOfInvalid(Vec v) {
        for (int i = 0; i < v.length(); ++i) {
            double a = v.get(i);
            if (!(a > 0.0) || Double.isInfinite(a)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Validates {@code alphas} as concentration parameters and returns a copy.
     *
     * @throws ArithmeticException naming the first invalid value
     */
    private static Vec validatedCopy(Vec alphas) {
        int bad = indexOfInvalid(alphas);
        if (bad >= 0) {
            throw new ArithmeticException("Dirichlet Distribution parameters must be positive, " + alphas.get(bad) + " is invalid");
        }
        return alphas.clone();
    }
}

