package jsat.distributions.multivariate;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.distributions.Gamma;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.SpecialMath;
import jsat.math.optimization.NelderMead;
import jsat.utils.concurrent.ParallelUtils;

/**
 * The Symmetric Dirichlet distribution: a Dirichlet distribution over
 * {@code K}-dimensional vectors in which every concentration parameter shares
 * the single value {@code alpha}.
 * <p>
 * The log density evaluated by {@link #logPdf(Vec)} is
 * {@code lnGamma(alpha*K) - K*lnGamma(alpha) + (alpha-1) * sum_i log(x_i)}
 * for points whose components are non-negative and sum to one.
 */
public class SymmetricDirichlet extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = -1206894014440494142L;

    /**
     * Allowed deviation of {@code sum(x)} from 1, per dimension, before
     * {@link #logPdf(Vec)} treats {@code x} as lying off the simplex. Rounding
     * error of a sum grows with the number of terms, hence the per-dimension
     * scaling.
     */
    private static final double SUM_TOLERANCE_PER_DIM = 1e-14;

    /** Shared concentration parameter; always positive and finite. */
    private double alpha;
    /** Number of dimensions {@code K}; always positive. */
    private int dim;

    /**
     * Creates a new Symmetric Dirichlet distribution.
     *
     * @param d the shared concentration parameter {@code alpha}
     * @param i the number of dimensions {@code K}
     * @throws ArithmeticException if {@code d} is not positive and finite, or
     *         {@code i} is not positive
     */
    public SymmetricDirichlet(double d, int i) {
        setAlpha(d);
        setDimension(i);
    }

    /**
     * Sets the number of dimensions of the distribution.
     *
     * @param i the number of dimensions
     * @throws ArithmeticException if {@code i} is not positive
     */
    public void setDimension(int i) {
        if (i <= 0) {
            throw new ArithmeticException("A positive number of dimensions must be given");
        }
        this.dim = i;
    }

    /**
     * Returns the number of dimensions of the distribution.
     *
     * @return the number of dimensions
     */
    public int getDimension() {
        return this.dim;
    }

    /**
     * Sets the shared concentration parameter.
     *
     * @param d the new value of {@code alpha}
     * @throws ArithmeticException if {@code d} is zero, negative, NaN or infinite
     */
    public void setAlpha(double d) throws ArithmeticException {
        if (d <= 0.0d || Double.isNaN(d) || Double.isInfinite(d)) {
            throw new ArithmeticException("Symmetric Dirichlet Distribution parameters must be positive, " + d + " is invalid");
        }
        this.alpha = d;
    }

    /**
     * Returns the shared concentration parameter.
     *
     * @return the current value of {@code alpha}
     */
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public SymmetricDirichlet clone() {
        return new SymmetricDirichlet(this.alpha, this.dim);
    }

    /**
     * Computes the log of the density at {@code x}.
     * <p>
     * Points whose components do not sum to one (within a dimension-scaled
     * tolerance), that contain a negative or NaN component, or whose log density
     * is not finite are reported as {@code -Double.MAX_VALUE}.
     *
     * @param x the point to evaluate; must have {@link #getDimension()} entries
     * @return the log density, or {@code -Double.MAX_VALUE} as described above
     * @throws ArithmeticException if {@code x} has the wrong length
     */
    @Override
    public double logPdf(Vec x) {
        final int k = x.length();
        if (k != this.dim) {
            throw new ArithmeticException(this.dim + " variable distribution can not awnser a " + k + " dimension variable");
        }
        if (Math.abs(x.sum() - 1.0) > SUM_TOLERANCE_PER_DIM * k) {
            return -Double.MAX_VALUE;
        }

        double sumLog = 0.0;
        for (int i = 0; i < k; i++) {
            final double xi = x.get(i);
            if (!(xi >= 0.0)) { // negated test so NaN is rejected as well
                return -Double.MAX_VALUE;
            }
            sumLog += Math.log(xi);
        }

        // With alpha == 1 the kernel is identically zero; skipping it avoids
        // 0 * -inf = NaN for points with a zero component.
        final double kernel = this.alpha == 1.0 ? 0.0 : (this.alpha - 1.0) * sumLog;
        final double logDensity = kernel + SpecialMath.lnGamma(this.alpha * k) - SpecialMath.lnGamma(this.alpha) * k;
        // NOTE(review): for alpha < 1 a zero component makes the true density
        // +infinity, yet it is reported as -Double.MAX_VALUE here, as before.
        if (Double.isNaN(logDensity) || Double.isInfinite(logDensity)) {
            return -Double.MAX_VALUE;
        }
        return logDensity;
    }

    /**
     * Computes the density at {@code x} as {@code exp(logPdf(x))}.
     *
     * @param x the point to evaluate
     * @return the density at {@code x}
     */
    @Override
    public double pdf(Vec x) {
        return Math.exp(logPdf(x));
    }

    /**
     * Fits {@code alpha} by maximum likelihood, minimizing the negative
     * log-likelihood with Nelder-Mead from starting points 1, 0.1 and 10.
     * The dimension is not changed.
     *
     * @param list the observed points
     * @param parallel whether the data pass may run in parallel
     * @return {@code true} if a positive, finite {@code alpha} was found and
     *         stored; {@code false} (leaving {@code alpha} unchanged) if
     *         {@code list} is empty or the optimizer produced an invalid value
     */
    @Override
    public <V extends Vec> boolean setUsingData(final List<V> list, final boolean parallel) {
        final int n = list.size();
        if (n == 0) {
            return false;
        }
        final int k = this.dim;

        // sum_j sum_i log(x_ji) does not depend on alpha, so compute it once
        // rather than on every objective evaluation.
        final Double total = ParallelUtils.run(parallel, n, (start, end) -> {
            double acc = 0.0;
            for (int j = start; j < end; j++) {
                final Vec point = list.get(j);
                for (int i = 0; i < point.length(); i++) {
                    acc += Math.log(point.get(i));
                }
            }
            return acc;
        }, (x, y) -> x + y);
        final double sumLog = total;

        // Negative log-likelihood of the data as a function of alpha.
        // Non-positive or infinite alpha, and NaN results, are given the worst
        // possible value so the optimizer is steered back into the valid region.
        final Function negLogLikelihood = (candidate, parallelHint) -> {
            final double a = candidate.get(0);
            if (!(a > 0.0) || Double.isInfinite(a)) {
                return Double.MAX_VALUE;
            }
            final double value = -((a - 1.0) * sumLog
                    + n * (SpecialMath.lnGamma(a * k) - SpecialMath.lnGamma(a) * k));
            return Double.isNaN(value) ? Double.MAX_VALUE : value;
        };

        final DenseVector origin = new DenseVector(1);
        final List<Vec> initialSimplex = new ArrayList<>(3);
        initialSimplex.add(origin.add(1.0d));
        initialSimplex.add(origin.add(0.1d));
        initialSimplex.add(origin.add(10.0d));

        final double fitted = new NelderMead()
                .optimize(1.0E-10d, 100, negLogLikelihood, initialSimplex, parallel)
                .get(0);
        if (!(fitted > 0.0) || Double.isInfinite(fitted)) {
            return false;
        }
        this.alpha = fitted;
        return true;
    }

    /**
     * Draws samples by generating {@code dim} Gamma variates per sample and
     * normalizing each group to sum to one.
     *
     * @param count the number of samples to draw
     * @param rand the source of randomness
     * @return a list of {@code count} sampled vectors
     * @throws ArithmeticException if {@code count * dim} overflows an int
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        final List<Vec> samples = new ArrayList<>(count);
        // Gamma constructed with (alpha, 1.0) - presumably shape alpha and unit
        // scale, the standard construction for Dirichlet draws.
        final double[] gammas = new Gamma(this.alpha, 1.0d).sample(Math.multiplyExact(count, this.dim), rand);
        int next = 0;
        for (int s = 0; s < count; s++) {
            final DenseVector point = new DenseVector(this.dim);
            for (int i = 0; i < this.dim; i++) {
                point.set(i, gammas[next++]);
            }
            point.mutableDivide(point.sum());
            samples.add(point);
        }
        return samples;
    }
}
