package org.jdmp.core.algorithm.compression;

import java.util.Iterator;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/* loaded from: input_file:org/jdmp/core/algorithm/compression/PCA.class */
/**
 * Principal component analysis compressor.
 *
 * <p>Training standardizes every input feature: it subtracts the per-feature mean and divides by
 * the per-feature standard deviation, both measured on the training set. It then takes the
 * {@code U} factor of an SVD of the scatter matrix {@code X^T X} of the standardized data. The
 * columns of {@code U} are the principal axes.
 *
 * <p>{@link #compress} standardizes an input vector and projects it onto those axes.
 * {@link #decompress} maps a projected vector back into the original feature space.
 */
public class PCA extends AbstractCompressor {
    private static final long serialVersionUID = 4559351198783166902L;

    /** Per-feature mean of the training data, one entry per column; null until trained. */
    private Matrix mean;

    /** Per-feature standard deviation of the centered training data; null until trained. */
    private Matrix std;

    /** U factor of the SVD of the standardized scatter matrix (principal axes as columns). */
    private Matrix u;

    /** Number of principal components to keep; 0 keeps all of them. */
    private int numberOfPrincipalComponents;

    /**
     * Creates a PCA compressor.
     *
     * @param numberOfPrincipalComponents number of principal components to keep; 0 keeps all
     */
    public PCA(int numberOfPrincipalComponents) {
        // mean, std and u default to null, meaning "not trained yet".
        this.numberOfPrincipalComponents = numberOfPrincipalComponents;
    }

    /** Creates a PCA compressor that keeps all principal components. */
    public PCA() {
        this(0);
    }

    /** Discards the trained model; {@link #train} must be called again before use. */
    @Override
    public void reset() {
        this.mean = null;
        this.std = null;
        this.u = null;
    }

    /**
     * Fits the model to the input vectors of all samples in the data set.
     *
     * @param listDataSet the training samples; must not be null or empty
     * @throws IllegalArgumentException if the data set is null or empty
     */
    @Override
    public void train(ListDataSet listDataSet) {
        if (listDataSet == null || listDataSet.size() == 0) {
            throw new IllegalArgumentException("cannot train PCA on an empty data set");
        }
        System.out.println("training started");
        DenseMatrix data = loadData(listDataSet);
        System.out.println("data loaded");

        this.mean = data.mean(Calculation.Ret.NEW, 0, true);
        subtractRow(data, this.mean);

        this.std = data.std(Calculation.Ret.NEW, 0, true, true);
        // A constant feature has std 0. Dividing by it would fill the column with NaN (0/0)
        // and poison the SVD below, so such entries are replaced with 1.
        replaceDegenerateStd(this.std);
        divideByRow(data, this.std);

        Matrix scatter = data.transpose().mtimes(data);
        Matrix[] factors = this.numberOfPrincipalComponents == 0
                ? scatter.svd()
                : scatter.svd(this.numberOfPrincipalComponents);
        this.u = factors[0];
        System.out.println("training finished");
    }

    /**
     * Standardizes {@code matrix} with the training statistics and projects it onto the
     * principal axes.
     *
     * @param matrix an input sample with the same number of features as the training data
     * @return the projected vector
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public Matrix compress(Matrix matrix) {
        requireTrained();
        return matrix.toColumnVector(Calculation.Ret.LINK)
                .minus(Calculation.Ret.LINK, true, this.mean)
                .divide(this.std)
                .mtimes(this.u);
    }

    /**
     * Maps a vector produced by {@link #compress} back into the original feature space. It
     * reverses the projection using the transpose of the principal axes, then undoes the scaling
     * and centering. The result is exact when all components are kept. Otherwise it is the
     * best reconstruction from the retained components.
     *
     * @param matrix a compressed vector, as returned by {@link #compress}
     * @return the reconstructed feature vector
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public Matrix decompress(Matrix matrix) {
        requireTrained();
        return matrix.toColumnVector(Calculation.Ret.LINK)
                .mtimes(this.u.transpose())
                .times(this.std)
                .plus(this.mean);
    }

    /** Throws if {@link #train} has not run since construction or the last {@link #reset}. */
    private void requireTrained() {
        if (this.u == null || this.mean == null || this.std == null) {
            throw new IllegalStateException("PCA has not been trained");
        }
    }

    /** Copies each sample's input vector into its own row of a new dense matrix. */
    private DenseMatrix loadData(ListDataSet listDataSet) {
        DenseMatrix data = Matrix.Factory.zeros(listDataSet.size(), getFeatureCount(listDataSet));
        long row = 0;
        Iterator<?> it = listDataSet.iterator();
        while (it.hasNext()) {
            Matrix vector = ((Sample) it.next()).getAsMatrix(getInputLabel())
                    .toColumnVector(Calculation.Ret.LINK);
            for (long col = 0; col < vector.getColumnCount(); col++) {
                data.setAsDouble(vector.getAsDouble(0, col), row, col);
            }
            row++;
        }
        return data;
    }

    /** Subtracts the single-row matrix {@code values} from every row of {@code data}, in place. */
    private static void subtractRow(Matrix data, Matrix values) {
        for (long r = 0; r < data.getRowCount(); r++) {
            for (long c = 0; c < data.getColumnCount(); c++) {
                data.setAsDouble(data.getAsDouble(r, c) - values.getAsDouble(0, c), r, c);
            }
        }
    }

    /** Divides every row of {@code data} element-wise by the single-row matrix {@code values}, in place. */
    private static void divideByRow(Matrix data, Matrix values) {
        for (long r = 0; r < data.getRowCount(); r++) {
            for (long c = 0; c < data.getColumnCount(); c++) {
                data.setAsDouble(data.getAsDouble(r, c) / values.getAsDouble(0, c), r, c);
            }
        }
    }

    /** Replaces zero or NaN entries of the single-row matrix {@code std} with 1, in place. */
    private static void replaceDegenerateStd(Matrix std) {
        for (long c = 0; c < std.getColumnCount(); c++) {
            double s = std.getAsDouble(0, c);
            if (s == 0.0 || Double.isNaN(s)) {
                std.setAsDouble(1.0, 0, c);
            }
        }
    }
}
