/*
 * Decompiled with CFR 0.152.
 */
package com.datumbox.framework.core.machinelearning.clustering;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.algorithms.AbstractDPMM;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.OpenMapRealMatrix;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

public class GaussianDPMM
extends AbstractDPMM<Cluster, ModelParameters, TrainingParameters> {
    protected GaussianDPMM(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    protected GaussianDPMM(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    @Override
    protected Cluster createNewCluster(Integer clusterId) {
        ModelParameters modelParameters = (ModelParameters)this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters)this.knowledgeBase.getTrainingParameters();
        Cluster c = new Cluster(clusterId, modelParameters.getD(), trainingParameters.getKappa0(), trainingParameters.getNu0(), trainingParameters.getMu0(), trainingParameters.getPsi0());
        c.setFeatureIds(modelParameters.getFeatureIds());
        return c;
    }

    public static class TrainingParameters
    extends AbstractDPMM.AbstractTrainingParameters {
        private static final long serialVersionUID = 2L;
        private int kappa0 = 0;
        private int nu0 = 1;
        private RealVector mu0;
        private RealMatrix psi0;

        public int getKappa0() {
            return this.kappa0;
        }

        public void setKappa0(int kappa0) {
            this.kappa0 = kappa0;
        }

        public int getNu0() {
            return this.nu0;
        }

        public void setNu0(int nu0) {
            this.nu0 = nu0;
        }

        public RealVector getMu0() {
            return this.mu0;
        }

        public void setMu0(RealVector mu0) {
            this.mu0 = mu0;
        }

        public RealMatrix getPsi0() {
            return this.psi0;
        }

        public void setPsi0(RealMatrix psi0) {
            this.psi0 = psi0;
        }
    }

    public static class ModelParameters
    extends AbstractDPMM.AbstractModelParameters<Cluster> {
        private static final long serialVersionUID = 1L;

        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }
    }

    public static class Cluster
    extends AbstractDPMM.AbstractCluster {
        private static final long serialVersionUID = 2L;
        private final int dimensions;
        private final int kappa0;
        private final int nu0;
        private final RealVector mu0;
        private final RealMatrix psi0;
        private RealVector mean;
        private RealMatrix covariance;
        private RealMatrix meanError;
        private int meanDf;
        private RealVector xi_sum;
        private RealMatrix xi_square_sum;
        private volatile Double cache_covariance_determinant;
        private volatile Array2DRowRealMatrix cache_covariance_inverse;

        protected Cluster(Integer clusterId, int dimensions, int kappa0, int nu0, RealVector mu0, RealMatrix psi0) {
            super(clusterId);
            if (nu0 < dimensions) {
                nu0 = dimensions;
            }
            this.mean = new OpenMapRealVector(dimensions);
            this.covariance = new OpenMapRealMatrix(dimensions, dimensions);
            for (int i = 0; i < dimensions; ++i) {
                this.covariance.setEntry(i, i, 1.0);
            }
            this.meanError = this.calculateMeanError(psi0, kappa0, nu0);
            this.meanDf = nu0 - dimensions + 1;
            this.kappa0 = kappa0;
            this.nu0 = nu0;
            this.mu0 = new OpenMapRealVector(mu0);
            this.psi0 = new OpenMapRealMatrix(dimensions, dimensions).add(psi0);
            this.dimensions = dimensions;
            this.xi_sum = new OpenMapRealVector(dimensions);
            this.xi_square_sum = new OpenMapRealMatrix(dimensions, dimensions);
            this.cache_covariance_determinant = null;
            this.cache_covariance_inverse = null;
        }

        private void assertModifiable() {
            if (this.xi_sum == null || this.xi_square_sum == null) {
                throw new RuntimeException("The cluster parameters are already estimated.");
            }
        }

        @Override
        protected Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        @Override
        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        protected RealMatrix getMeanError() {
            return this.meanError;
        }

        protected int getMeanDf() {
            return this.meanDf;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        protected double posteriorLogPdf(Record r) {
            RealVector x_mu = DataframeMatrix.parseRecord(r, this.featureIds);
            x_mu = x_mu.subtract(this.mean);
            if (this.cache_covariance_determinant == null || this.cache_covariance_inverse == null) {
                Cluster cluster = this;
                synchronized (cluster) {
                    if (this.cache_covariance_determinant == null || this.cache_covariance_inverse == null) {
                        LUDecomposition lud = new LUDecomposition(this.covariance);
                        this.cache_covariance_determinant = lud.getDeterminant();
                        this.cache_covariance_inverse = (Array2DRowRealMatrix)lud.getSolver().getInverse();
                    }
                }
            }
            double x_muInvSx_muT = this.cache_covariance_inverse.preMultiply(x_mu).dotProduct(x_mu);
            double normConst = 1.0 / (Math.pow(Math.PI * 2, (double)this.dimensions / 2.0) * Math.pow(this.cache_covariance_determinant, 0.5));
            double logPdf = -0.5 * x_muInvSx_muT + Math.log(normConst);
            return logPdf;
        }

        @Override
        protected void add(Record r) {
            this.assertModifiable();
            RealVector rv = DataframeMatrix.parseRecord(r, this.featureIds);
            this.xi_sum = this.xi_sum.add(rv);
            this.xi_square_sum = this.xi_square_sum.add(rv.outerProduct(rv));
            ++this.size;
            this.updateClusterParameters();
        }

        @Override
        protected void remove(Record r) {
            this.assertModifiable();
            if (this.size == 0) {
                throw new IllegalArgumentException("The cluster is empty.");
            }
            --this.size;
            RealVector rv = DataframeMatrix.parseRecord(r, this.featureIds);
            this.xi_sum = this.xi_sum.subtract(rv);
            this.xi_square_sum = this.xi_square_sum.subtract(rv.outerProduct(rv));
            this.updateClusterParameters();
        }

        private RealMatrix calculateMeanError(RealMatrix Psi, int kappa, int nu) {
            return Psi.scalarMultiply(1.0 / ((double)kappa * ((double)(nu - this.dimensions) + 1.0)));
        }

        @Override
        protected void clear() {
            this.xi_sum = null;
            this.xi_square_sum = null;
            this.cache_covariance_determinant = null;
            this.cache_covariance_inverse = null;
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        @Override
        protected void updateClusterParameters() {
            this.assertModifiable();
            int kappa_n = this.kappa0 + this.size;
            int nu = this.nu0 + this.size;
            RealVector mu = this.xi_sum.mapDivide((double)this.size);
            RealVector mu_mu0 = mu.subtract(this.mu0);
            RealMatrix C = this.xi_square_sum.subtract(mu.outerProduct(mu).scalarMultiply((double)this.size));
            RealMatrix psi = this.psi0.add(C.add(mu_mu0.outerProduct(mu_mu0).scalarMultiply((double)(this.kappa0 * this.size) / (double)kappa_n)));
            this.mean = this.mu0.mapMultiply((double)this.kappa0).add(mu.mapMultiply((double)this.size)).mapDivide((double)kappa_n);
            Cluster cluster = this;
            synchronized (cluster) {
                this.covariance = psi.scalarMultiply(((double)kappa_n + 1.0) / ((double)kappa_n * ((double)(nu - this.dimensions) + 1.0)));
                this.cache_covariance_determinant = null;
                this.cache_covariance_inverse = null;
            }
            this.meanError = this.calculateMeanError(psi, kappa_n, nu);
            this.meanDf = nu - this.dimensions + 1;
        }
    }
}

