package jsat.distributions.multivariate;

import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.distributions.Gamma;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.SpecialMath;
import jsat.math.optimization.NelderMead;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/distributions/multivariate/Dirichlet.class */
/**
 * Dirichlet distribution over vectors that lie on the probability simplex
 * (all coordinates strictly positive and summing to one), parameterized by a
 * vector of strictly positive concentration values ("alphas").
 *
 * <p>NOTE(review): this class was recovered from decompiled bytecode. Local
 * names and the body of the likelihood lambda were reconstructed from the
 * visible arithmetic; confirm against the upstream sources if available.
 */
public class Dirichlet extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = 6229508050763067569L;

    /** Concentration parameters; one entry per dimension. */
    private Vec alphas;

    /**
     * Creates a Dirichlet distribution with the given concentration parameters.
     *
     * @param alphas the concentration parameters; a copy is stored
     * @throws ArithmeticException if any entry is non-positive, NaN or infinite
     */
    public Dirichlet(Vec alphas) {
        setAlphas(alphas);
    }

    /**
     * Replaces the concentration parameters with a copy of the given vector.
     * The stored parameters are left untouched if validation fails.
     *
     * @param alphas the new concentration parameters
     * @throws ArithmeticException if any entry is non-positive, NaN or infinite
     */
    public void setAlphas(Vec alphas) throws ArithmeticException {
        for (int i = 0; i < alphas.length(); i++) {
            final double alpha = alphas.get(i);
            if (alpha <= 0.0 || Double.isNaN(alpha) || Double.isInfinite(alpha)) {
                throw new ArithmeticException("Dirichlet Distribution parameters must be positive, " + alpha + " is invalid");
            }
        }
        // Defensive copy so later changes to the caller's vector do not leak in.
        this.alphas = alphas.clone();
    }

    /**
     * Returns the concentration parameters.
     *
     * @return the internally stored vector itself (not a copy)
     */
    public Vec getAlphas() {
        return this.alphas;
    }

    /**
     * Creates a new distribution with the same concentration parameters.
     * The constructor copies the parameter vector, so the two instances do not
     * share it.
     */
    @Override
    public Dirichlet clone() {
        return new Dirichlet(this.alphas);
    }

    /**
     * Computes the natural log of the density at {@code x}.
     *
     * @param x the point to evaluate; must have as many entries as there are
     *          concentration parameters
     * @return the log density, or {@code -Double.MAX_VALUE} when {@code x} has
     *         a non-positive coordinate or its coordinates do not sum to one
     *         (within 1e-14)
     * @throws ArithmeticException if the dimension of {@code x} does not match
     */
    @Override
    public double logPdf(Vec x) {
        final int dim = this.alphas.length();
        if (x.length() != dim) {
            throw new ArithmeticException(dim + " variable distribution can not awnser a " + x.length() + " dimension variable");
        }
        double weightedLogSum = 0.0;
        double coordinateSum = 0.0;
        for (int i = 0; i < dim; i++) {
            final double xi = x.get(i);
            if (xi <= 0.0) {
                // Every coordinate must be strictly positive to have support.
                return -Double.MAX_VALUE;
            }
            coordinateSum += xi;
            weightedLogSum += Math.log(xi) * (this.alphas.get(i) - 1.0);
        }
        // Allow a little floating point slack, but the point must sum to one.
        if (Math.abs(coordinateSum - 1.0) > 1.0E-14) {
            return -Double.MAX_VALUE;
        }
        // Log normalizer: sum_i lnGamma(alpha_i) - lnGamma(sum_i alpha_i).
        double sumLnGamma = 0.0;
        for (int i = 0; i < dim; i++) {
            sumLnGamma += SpecialMath.lnGamma(this.alphas.get(i));
        }
        return weightedLogSum - (sumLnGamma - SpecialMath.lnGamma(this.alphas.sum()));
    }

    /**
     * Computes the density at {@code x} as {@code exp(logPdf(x))}.
     *
     * @param x the point to evaluate
     * @return the density value
     */
    @Override
    public double pdf(Vec x) {
        return Math.exp(logPdf(x));
    }

    /**
     * Fits the concentration parameters to the data by minimizing the negative
     * log-likelihood with {@link NelderMead}, starting from the constant
     * vectors 1.0, 0.1 and 10.0. The optimizer's result is stored directly
     * without re-validation.
     *
     * @param dataSet  the observations; the first element determines the
     *                 dimension (assumes a non-empty list whose vectors all
     *                 share that dimension — TODO confirm with callers)
     * @param parallel forwarded to the optimizer and to the likelihood sum
     * @return always {@code true}
     */
    @Override
    public <V extends Vec> boolean setUsingData(List<V> dataSet, boolean parallel) {
        Function negLogLikelihood = (candidate, runParallel) -> {
            // Per-observation constant: lnGamma(sum alpha) - sum lnGamma(alpha_j).
            double logNormalizer = SpecialMath.lnGamma(candidate.sum());
            for (int j = 0; j < candidate.length(); j++) {
                logNormalizer -= SpecialMath.lnGamma(candidate.get(j));
            }
            // Data-dependent part: sum over observations of log(x_j) * (alpha_j - 1).
            double dataTerm = ParallelUtils.run(runParallel, dataSet.size(), (start, end) -> {
                double partial = 0.0;
                for (int n = start; n < end; n++) {
                    Vec observation = dataSet.get(n);
                    for (int j = 0; j < candidate.length(); j++) {
                        partial += Math.log(observation.get(j)) * (candidate.get(j) - 1.0);
                    }
                }
                return partial;
            }, (a, b) -> a + b);
            // Negated because the optimizer minimizes.
            return -(dataTerm + logNormalizer * dataSet.size());
        };
        NelderMead optimizer = new NelderMead();
        Vec zeros = new DenseVector(dataSet.get(0).length());
        List<Vec> initialGuesses = new ArrayList<>();
        initialGuesses.add(zeros.add(1.0));
        initialGuesses.add(zeros.add(0.1));
        initialGuesses.add(zeros.add(10.0));
        // NOTE(review): 1e-10 and 100 are presumably the tolerance and iteration limit.
        this.alphas = optimizer.optimize(1.0E-10, 100, negLogLikelihood, initialGuesses, parallel);
        return true;
    }

    /**
     * Draws {@code count} samples. For each dimension {@code j}, {@code count}
     * values are drawn from {@code new Gamma(alpha_j, 1.0)}; the i-th sample
     * takes the i-th draw of every dimension and is divided by its own sum.
     *
     * @param count the number of samples to draw
     * @param rand  the source of randomness
     * @return a list of {@code count} newly allocated vectors
     */
    @Override
    public List<Vec> sample(int count, Random rand) {
        List<Vec> samples = new ArrayList<>(count);
        final int dim = this.alphas.length();
        // gammaDraws[j][i] is the i-th draw for dimension j.
        double[][] gammaDraws = new double[dim][];
        for (int j = 0; j < dim; j++) {
            gammaDraws[j] = new Gamma(this.alphas.get(j), 1.0).sample(count, rand);
        }
        for (int i = 0; i < count; i++) {
            Vec sample = new DenseVector(dim);
            for (int j = 0; j < dim; j++) {
                sample.set(j, gammaDraws[j][i]);
            }
            sample.mutableDivide(sample.sum());
            samples.add(sample);
        }
        return samples;
    }

    // The decompiled $deserializeLambda$(SerializedLambda) method is intentionally
    // not reproduced: it is synthesized by the compiler for serializable lambdas
    // and must not be declared in source.
}
