package com.datumbox.framework.core.machinelearning.featureselection;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.TypeInference;
import com.datumbox.framework.common.storage.interfaces.BigMap;
import com.datumbox.framework.common.storage.interfaces.StorageEngine;
import com.datumbox.framework.core.common.dataobjects.Dataframe;
import com.datumbox.framework.core.common.dataobjects.DataframeMatrix;
import com.datumbox.framework.core.common.dataobjects.Record;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.OpenMapRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.FastMath;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/featureselection/PCA.class */
/**
 * Principal Component Analysis (PCA) feature selector.
 *
 * <p>Fitting mean-centres the X columns, eigen-decomposes their sample covariance matrix and keeps
 * the leading eigenvectors as components (optionally whitened and truncated). Transforming replaces
 * the X part of every record with its mean-centred coordinates in that component space, keyed by
 * component index {@code 0..k-1}.
 */
public class PCA extends AbstractFeatureSelector<PCA.ModelParameters, PCA.TrainingParameters> {

    /** Parameters learned by {@link PCA} during fitting. */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 2;

        /** Maps each X column key to its column index in the data matrix. */
        @BigMap(keyClass = Object.class, valueClass = Integer.class, mapType = StorageEngine.MapType.HASHMAP, storageHint = StorageEngine.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Integer> featureIds;

        /** Per-column mean of the training data, indexed by column index. */
        private RealVector mean;

        /** Covariance eigenvalues of the retained components (unaffected by whitening). */
        private RealVector eigenValues;

        /** Retained eigenvectors stored column-wise (scaled when whitening is enabled). */
        private RealMatrix components;

        protected ModelParameters(StorageEngine storageEngine) {
            super(storageEngine);
        }

        /** @return the mapping from X column key to matrix column index */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        protected void setFeatureIds(Map<Object, Integer> featureIds) {
            this.featureIds = featureIds;
        }

        /** @return the per-column training means */
        public RealVector getMean() {
            return this.mean;
        }

        protected void setMean(RealVector mean) {
            this.mean = mean;
        }

        /** @return the eigenvalues of the retained components */
        public RealVector getEigenValues() {
            return this.eigenValues;
        }

        protected void setEigenValues(RealVector eigenValues) {
            this.eigenValues = eigenValues;
        }

        /** @return the projection matrix (one column per retained component) */
        public RealMatrix getComponents() {
            return this.components;
        }

        protected void setComponents(RealMatrix components) {
            this.components = components;
        }
    }

    /** User-supplied settings controlling how {@link PCA} is fitted. */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1;

        /** When true, each component is scaled so its projected coordinate has unit variance. */
        private boolean whitened = false;

        /** Upper bound on the number of components kept; {@code null} means no bound. */
        private Integer maxDimensions = null;

        /**
         * Fraction (0..1] of total variance the kept components must explain; {@code null} or a
         * value above 1 disables this criterion.
         */
        private Double variancePercentageThreshold = null;

        public boolean isWhitened() {
            return this.whitened;
        }

        public void setWhitened(boolean whitened) {
            this.whitened = whitened;
        }

        public Integer getMaxDimensions() {
            return this.maxDimensions;
        }

        public void setMaxDimensions(Integer maxDimensions) {
            this.maxDimensions = maxDimensions;
        }

        public Double getVariancePercentageThreshold() {
            return this.variancePercentageThreshold;
        }

        public void setVariancePercentageThreshold(Double variancePercentageThreshold) {
            this.variancePercentageThreshold = variancePercentageThreshold;
        }
    }

    protected PCA(TrainingParameters trainingParameters, Configuration configuration) {
        super(trainingParameters, configuration);
    }

    protected PCA(String storageName, Configuration configuration) {
        super(storageName, configuration);
    }

    /**
     * Rejects dataframes containing X columns of unsupported type, then delegates to the base fit.
     *
     * @param dataframe the training data
     * @throws IllegalArgumentException if an X column has a type outside {@link #getSupportedXDataTypes()}
     */
    @Override
    public void fit(Dataframe dataframe) {
        Set<TypeInference.DataType> supported = getSupportedXDataTypes();
        for (TypeInference.DataType dataType : dataframe.getXDataTypes().values()) {
            if (!supported.contains(dataType)) {
                throw new IllegalArgumentException("A DataType that is not supported by this method was detected in the Dataframe.");
            }
        }
        super.fit(dataframe);
    }

    /**
     * Learns the column means, the covariance eigen-decomposition and the retained components.
     *
     * @param dataframe the training data
     * @throws IllegalArgumentException if fewer than two records are given (the sample covariance
     *     divides by {@code n - 1}) or if {@code maxDimensions} is not positive
     */
    @Override
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) this.knowledgeBase.getTrainingParameters();

        int n = dataframe.size();
        if (n < 2) {
            throw new IllegalArgumentException("PCA requires at least two records to estimate the covariance matrix.");
        }
        int d = dataframe.xColumnSize();

        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        RealMatrix x = DataframeMatrix.newInstance(dataframe, false, null, featureIds).getX();
        modelParameters.setMean(centerColumns(x, featureIds.values(), n, d));

        // Sample covariance of the (now centred) data.
        RealMatrix covariance = x.transpose().multiply(x).scalarMultiply(1.0d / (n - 1.0d));
        EigenDecomposition decomposition = new EigenDecomposition(covariance);
        RealVector eigenValues = new ArrayRealVector(decomposition.getRealEigenvalues(), false);
        RealMatrix components = decomposition.getV();

        if (trainingParameters.isWhitened()) {
            components = components.multiply(inverseSqrtDiagonal(eigenValues));
        }

        Integer k = componentsToKeep(eigenValues, trainingParameters.getMaxDimensions(), trainingParameters.getVariancePercentageThreshold());
        if (k != null && k < d) {
            if (k < 1) {
                throw new IllegalArgumentException("The maximum number of dimensions must be positive.");
            }
            // Keep the first k eigenpairs; assumes leading entries carry the most variance
            // (commons-math3 sorts symmetric eigenvalues in descending order) - confirm on upgrade.
            eigenValues = eigenValues.getSubVector(0, k);
            components = components.getSubMatrix(0, components.getRowDimension() - 1, 0, k - 1);
        }
        modelParameters.setEigenValues(eigenValues);
        modelParameters.setComponents(components);
    }

    /**
     * Subtracts each listed column's mean from {@code x} in place.
     *
     * @param x the data matrix, modified in place
     * @param columnIds the column indices to centre
     * @param rows number of rows of {@code x} to use
     * @param dimension length of the returned vector
     * @return the column means, indexed by column index
     */
    private static RealVector centerColumns(RealMatrix x, Iterable<Integer> columnIds, int rows, int dimension) {
        RealVector means = new OpenMapRealVector(dimension);
        for (int column : columnIds) {
            double sum = 0.0d;
            for (int row = 0; row < rows; row++) {
                sum += x.getEntry(row, column);
            }
            double mean = sum / rows;
            for (int row = 0; row < rows; row++) {
                x.addToEntry(row, column, -mean);
            }
            means.setEntry(column, mean);
        }
        return means;
    }

    /**
     * Builds diag(1/sqrt(lambda_i)). Right-multiplying the eigenvectors by it gives each projected
     * coordinate unit variance. Non-positive eigenvalues (zero-variance directions, possibly slightly
     * negative from round-off) get a zero scale rather than an infinite or NaN one.
     */
    private static RealMatrix inverseSqrtDiagonal(RealVector eigenValues) {
        int dim = eigenValues.getDimension();
        DiagonalMatrix scale = new DiagonalMatrix(dim);
        for (int i = 0; i < dim; i++) {
            double lambda = eigenValues.getEntry(i);
            scale.setEntry(i, i, lambda > 0.0d ? 1.0d / FastMath.sqrt(lambda) : 0.0d);
        }
        return scale;
    }

    /**
     * Decides how many leading components to keep.
     *
     * <p>If a variance threshold in (-inf, 1] is set, counts the leading components whose cumulative
     * share of the total eigenvalue sum first reaches it, and returns the smaller of that count and
     * {@code maxDimensions}. Otherwise returns {@code maxDimensions} unchanged.
     *
     * @return the number of components to keep, or {@code null} to keep all of them
     */
    private static Integer componentsToKeep(RealVector eigenValues, Integer maxDimensions, Double varianceThreshold) {
        // Negated form so a NaN threshold is ignored, as is a null or >1 one.
        if (varianceThreshold == null || !(varianceThreshold <= 1.0d)) {
            return maxDimensions;
        }
        int dim = eigenValues.getDimension();
        double total = 0.0d;
        for (int i = 0; i < dim; i++) {
            total += eigenValues.getEntry(i);
        }
        double explained = 0.0d;
        int count = 0;
        for (int i = 0; i < dim; i++) {
            explained += eigenValues.getEntry(i) / total;
            count++;
            if (explained >= varianceThreshold) {
                break;
            }
        }
        return (maxDimensions == null || maxDimensions > count) ? Integer.valueOf(count) : maxDimensions;
    }

    /**
     * Replaces every record's X with its coordinates in the fitted component space.
     *
     * <p>Data is centred with the training mean before projection. This uses the identity
     * {@code (x - mean) V = x V - mean V}, so the sparse data matrix is not densified.
     *
     * @param dataframe the data to transform in place
     */
    @Override
    protected void _transform(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) this.knowledgeBase.getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        RealMatrix components = modelParameters.getComponents();

        // Filled by parseDataset: record id -> row index in the parsed matrix.
        Map<Integer, Integer> recordIdToRow = new HashMap<>();
        DataframeMatrix parsed = DataframeMatrix.parseDataset(dataframe, recordIdToRow, featureIds);

        RealMatrix projected = parsed.getX().multiply(components);
        RealVector projectedMean = components.preMultiply(modelParameters.getMean());

        this.streamExecutor.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), entry -> {
            Integer recordId = (Integer) entry.getKey();
            Record record = (Record) entry.getValue();
            int row = recordIdToRow.get(recordId);

            double[] coordinates = projected.getRow(row);
            AssociativeArray xData = new AssociativeArray();
            for (int j = 0; j < coordinates.length; j++) {
                xData.put(Integer.valueOf(j), Double.valueOf(coordinates[j] - projectedMean.getEntry(j)));
            }
            dataframe._unsafe_set(recordId, new Record(xData, record.getY(), record.getYPredicted(), record.getYPredictedProbabilities()));
        });
        dataframe.recalculateMeta();
    }

    /** @return the X data types PCA accepts: BOOLEAN and NUMERICAL */
    @Override
    public Set<TypeInference.DataType> getSupportedXDataTypes() {
        return new HashSet<>(Arrays.asList(TypeInference.DataType.BOOLEAN, TypeInference.DataType.NUMERICAL));
    }

    /**
     * @return {@code null}; presumably the base class treats this as "no constraint on Y"
     *     (NOTE(review): confirm in AbstractFeatureSelector)
     */
    @Override
    protected Set<TypeInference.DataType> getSupportedYDataTypes() {
        return null;
    }
}
