/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.compression;

import org.jdmp.core.algorithm.compression.AbstractCompressor;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/**
 * Principal component analysis used as a lossy compressor.
 *
 * <p>{@link #train(ListDataSet)} standardizes the training inputs column by column
 * (subtract the column mean, divide by the column standard deviation) and keeps the
 * left singular vectors of {@code X^T X} as the projection basis.
 * {@link #compress(Matrix)} applies the same standardization to one sample and
 * projects it onto that basis. {@link #decompress(Matrix)} maps a projected sample
 * back into the original feature space.
 */
public class PCA
extends AbstractCompressor {
    private static final long serialVersionUID = 4559351198783166902L;
    /** Per-feature mean of the training data (1 x F); {@code null} until trained. */
    private Matrix mean = null;
    /**
     * Per-feature scale used for standardization (1 x F). Zero/NaN standard deviations
     * are stored as 1. {@code null} until trained.
     */
    private Matrix std = null;
    /**
     * Left singular vectors of the standardized {@code X^T X} (F rows, one column per
     * kept component — assumed F x k for a truncated SVD). {@code null} until trained.
     */
    private Matrix u = null;
    /** Number of principal components to keep; {@code 0} keeps the full decomposition. */
    private final int numberOfPrincipalComponents;

    /**
     * Creates a PCA compressor.
     *
     * @param numberOfPrincipalComponents number of principal components to keep;
     *        {@code 0} keeps all of them
     * @throws IllegalArgumentException if {@code numberOfPrincipalComponents} is negative
     */
    public PCA(int numberOfPrincipalComponents) {
        if (numberOfPrincipalComponents < 0) {
            throw new IllegalArgumentException(
                    "numberOfPrincipalComponents must be >= 0, was " + numberOfPrincipalComponents);
        }
        this.numberOfPrincipalComponents = numberOfPrincipalComponents;
    }

    /** Creates a PCA compressor that keeps all principal components. */
    public PCA() {
        this(0);
    }

    /** Discards the trained model (mean, scale and projection basis). */
    @Override
    public void reset() {
        this.mean = null;
        this.std = null;
        this.u = null;
    }

    /**
     * Fits the model to the samples of {@code dataSet}.
     *
     * <p>Each sample's input matrix is flattened into one row of a data matrix X. The
     * columns of X are centred and scaled. The left singular vectors of {@code X^T X}
     * become the projection basis, truncated to {@code numberOfPrincipalComponents}
     * when that is non-zero. The model fields are only replaced after every step has
     * succeeded.
     *
     * @param dataSet training samples; must not be {@code null} or empty
     * @throws IllegalArgumentException if {@code dataSet} is {@code null} or empty
     */
    @Override
    public void train(ListDataSet dataSet) {
        if (dataSet == null || dataSet.size() == 0) {
            throw new IllegalArgumentException("cannot train PCA on a null or empty data set");
        }
        System.out.println("training started");
        final long featureCount = this.getFeatureCount(dataSet);
        DenseMatrix x = Matrix.Factory.zeros((long) dataSet.size(), featureCount);
        long row = 0;
        for (Sample s : dataSet) {
            Matrix input = toFeatureRow(s.getAsMatrix(this.getInputLabel()));
            for (long c = 0; c < input.getColumnCount(); c++) {
                x.setAsDouble(input.getAsDouble(0L, c), row, c);
            }
            row++;
        }
        System.out.println("data loaded");

        // Column means (dimension 0), arguments kept as in the original implementation.
        Matrix newMean = x.mean(Calculation.Ret.NEW, 0, true);
        for (long r = 0; r < x.getRowCount(); r++) {
            for (long c = 0; c < x.getColumnCount(); c++) {
                x.setAsDouble(x.getAsDouble(r, c) - newMean.getAsDouble(0L, c), r, c);
            }
        }

        // Column standard deviations; flags kept as in the original (presumably
        // ignoreNaN and Bessel's correction — verify against the UJMP version in use).
        Matrix newStd = x.std(Calculation.Ret.NEW, 0, true, true);
        // A constant feature (or a single training sample) gives std 0 or NaN, which
        // would turn the whole column into NaN and break the SVD. Leave such columns
        // unscaled instead.
        for (long c = 0; c < newStd.getColumnCount(); c++) {
            double sd = newStd.getAsDouble(0L, c);
            if (!(sd > 0.0)) {
                newStd.setAsDouble(1.0, 0L, c);
            }
        }
        for (long r = 0; r < x.getRowCount(); r++) {
            for (long c = 0; c < x.getColumnCount(); c++) {
                x.setAsDouble(x.getAsDouble(r, c) / newStd.getAsDouble(0L, c), r, c);
            }
        }

        Matrix xtx = x.transpose().mtimes(x);
        Matrix[] svd = this.numberOfPrincipalComponents == 0
                ? xtx.svd()
                : xtx.svd(this.numberOfPrincipalComponents);

        // Commit the model only once training has fully succeeded.
        this.mean = newMean;
        this.std = newStd;
        this.u = svd[0];
        System.out.println("training finished");
    }

    /**
     * Standardizes one sample with the training mean/scale and projects it onto the
     * principal directions.
     *
     * @param input the sample's values; flattened to a 1 x F row
     * @return the projected sample (one row)
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public Matrix compress(Matrix input) {
        requireTrained();
        return toFeatureRow(input)
                .minus(Calculation.Ret.LINK, true, this.mean)
                .divide(this.std)
                .mtimes(this.u);
    }

    /**
     * Maps a projected sample back into the original feature space. This inverts
     * {@link #compress(Matrix)} exactly when all components are kept, and otherwise
     * gives the least-squares reconstruction from the kept components.
     *
     * @param input a compressed sample, as produced by {@link #compress(Matrix)}
     * @return the reconstructed sample as a 1 x F row
     * @throws IllegalStateException if the model has not been trained
     */
    @Override
    public Matrix decompress(Matrix input) {
        requireTrained();
        // u has orthonormal columns, so u^T projects back; then undo the scaling
        // (element-wise times) and the centring.
        return toFeatureRow(input)
                .mtimes(this.u.transpose())
                .times(this.std)
                .plus(this.mean);
    }

    /** Throws if {@link #train(ListDataSet)} has not completed since the last reset. */
    private void requireTrained() {
        if (this.mean == null || this.std == null || this.u == null) {
            throw new IllegalStateException("PCA has not been trained");
        }
    }

    /**
     * Returns the values of {@code m} as a single row (1 x n).
     *
     * <p>The element order is whatever {@code toColumnVector} produces. If that call
     * yields an n x 1 column, the result is transposed so that feature j is always at
     * position (0, j). Training, compression and decompression therefore all see the
     * same layout.
     */
    private static Matrix toFeatureRow(Matrix m) {
        Matrix v = m.toColumnVector(Calculation.Ret.LINK);
        if (v.getColumnCount() == 1 && v.getRowCount() > 1) {
            return v.transpose();
        }
        return v;
    }
}

