/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.regression;

import org.jdmp.core.algorithm.regression.AbstractRegressor;
import org.jdmp.core.algorithm.regression.Regressor;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;

/**
 * Multivariate linear regression fitted by least squares on standardized inputs.
 * <p>
 * During {@link #trainAll(ListDataSet)} every input vector is copied into a design matrix behind a
 * leading bias column. Each column is centered on its mean and divided by its standard deviation
 * (columns whose standard deviation is zero or NaN are only centered), and the bias column is then
 * set to 1. The coefficients are obtained with a pseudo-inverse and stored as the matrix of the
 * {@value #PARAMETERS} variable: one row per design column (bias first), one column per target
 * dimension.
 * <p>
 * When {@code numberOfPrincipalComponents} is positive and smaller than both the sample count and
 * the input feature count, it is passed to {@code Matrix.pinv(int)} instead of the full
 * {@code pinv()} (presumably a rank-truncated pseudo-inverse — confirm against UJMP).
 */
public class LinearRegression extends AbstractRegressor {
    private static final long serialVersionUID = 3483912497269476834L;
    public static final String PARAMETERS = "Parameters";

    /** Per-column means of the design matrix (1 x (features + 1)); null until trained. */
    private Matrix mean = null;
    /** Per-column standard deviations of the centered design matrix; null until trained. */
    private Matrix std = null;
    /** Value handed to {@code pinv(int)}; 0 (or out of range) means the full pseudo-inverse. */
    private final int numberOfPrincipalComponents;

    /** Creates a regressor that always uses the full pseudo-inverse. */
    public LinearRegression() {
        this(0);
    }

    /**
     * Creates a regressor.
     *
     * @param numberOfPrincipalComponents value passed to {@code pinv(int)} when positive and smaller
     *                                    than both the sample and feature counts; otherwise ignored
     */
    public LinearRegression(int numberOfPrincipalComponents) {
        this.numberOfPrincipalComponents = numberOfPrincipalComponents;
        this.setParameterVariable(Variable.Factory.labeledVariable("Regression Parameters"));
    }

    /** Registers {@code variable} as the holder of the regression coefficients. */
    public void setParameterVariable(Variable variable) {
        this.setVariable(PARAMETERS, variable);
    }

    /** Returns the variable holding the regression coefficients. */
    public Variable getParameterVariable() {
        return (Variable) this.getVariableMap().get(PARAMETERS);
    }

    /** Returns the most recent coefficient matrix stored in the parameter variable. */
    public Matrix getParameterMatrix() {
        return (Matrix) this.getParameterVariable().getLast();
    }

    /**
     * Predicts the target vector for a single input.
     * <p>
     * The input is converted with {@code toColumnVector} and read along its first row. It is then
     * standardized with the statistics recorded by the last {@link #trainAll(ListDataSet)} call:
     * zero- or NaN-deviation columns are only centered, exactly as during training. Finally it is
     * prefixed with a bias value of 1 and multiplied by the parameter matrix.
     *
     * @param input feature values, expected to have the same length as during training
     * @return a 1 x k matrix of predictions, k being the number of target dimensions
     * @throws IllegalStateException if the model has not been trained, or has been reset
     */
    @Override
    public Matrix predictOne(Matrix input) {
        if (this.mean == null || this.std == null) {
            throw new IllegalStateException("LinearRegression has not been trained");
        }
        Matrix features = input.toColumnVector(Calculation.Ret.NEW);
        long featureCount = features.getColumnCount();
        Matrix x = Matrix.Factory.zeros(1L, featureCount + 1L);
        // Bias term: never standardized, mirrors column 0 of the training design matrix.
        x.setAsDouble(1.0, 0L, 0L);
        for (long c = 0; c < featureCount; c++) {
            long column = c + 1;
            double centered = features.getAsDouble(0L, c) - this.mean.getAsDouble(0L, column);
            // Same divisor as in training, so constant columns do not turn into 0/0 = NaN here.
            x.setAsDouble(centered / this.scaleFor(column), 0L, column);
        }
        return x.mtimes(this.getParameterMatrix());
    }

    /**
     * Fits the coefficients on all samples of {@code dataSet}.
     * <p>
     * Records the column means and standard deviations used later by {@link #predictOne(Matrix)}.
     * It then stores the (features + 1) x classCount coefficient matrix via
     * {@link #setParameterMatrix(Matrix)}. Progress messages are printed to standard output.
     *
     * @param dataSet training samples providing the input and target labels of this regressor
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        System.out.println("training started");
        int featureCount = this.getFeatureCount(dataSet);
        int classCount = this.getClassCount(dataSet);
        int sampleCount = dataSet.size();
        DenseMatrix x = Matrix.Factory.zeros((long) sampleCount, (long) (featureCount + 1));
        DenseMatrix y = Matrix.Factory.zeros((long) sampleCount, (long) classCount);
        this.loadSamples(dataSet, x, y);
        this.standardize(x);
        System.out.println("data loaded");
        Matrix parameters = this.solveLeastSquares(x, y, sampleCount, featureCount);
        System.out.println("training finished");
        this.setParameterMatrix(parameters);
    }

    /** Copies each sample's input into columns 1.. of {@code x}, and its target into {@code y}. */
    private void loadSamples(ListDataSet dataSet, Matrix x, Matrix y) {
        long row = 0;
        for (Sample sample : dataSet) {
            Matrix input = sample.getAsMatrix(this.getInputLabel()).toColumnVector(Calculation.Ret.LINK);
            for (long c = 0; c < input.getColumnCount(); c++) {
                // Column 0 is reserved for the bias term.
                x.setAsDouble(input.getAsDouble(0L, c), row, c + 1);
            }
            Matrix target = sample.getAsMatrix(this.getTargetLabel()).toColumnVector(Calculation.Ret.LINK);
            for (long c = 0; c < target.getColumnCount(); c++) {
                y.setAsDouble(target.getAsDouble(0L, c), row, c);
            }
            row++;
        }
    }

    /**
     * Centers and scales {@code x} in place, recording {@link #mean} and {@link #std}.
     * Column 0 is set to 1.0 afterwards.
     */
    private void standardize(Matrix x) {
        long rows = x.getRowCount();
        long cols = x.getColumnCount();
        this.mean = x.mean(Calculation.Ret.NEW, 0, true);
        for (long r = 0; r < rows; r++) {
            for (long c = 0; c < cols; c++) {
                x.setAsDouble(x.getAsDouble(r, c) - this.mean.getAsDouble(0L, c), r, c);
            }
        }
        this.std = x.std(Calculation.Ret.NEW, 0, true, true);
        for (long r = 0; r < rows; r++) {
            for (long c = 0; c < cols; c++) {
                x.setAsDouble(x.getAsDouble(r, c) / this.scaleFor(c), r, c);
            }
            x.setAsDouble(1.0, r, 0L);
        }
    }

    /**
     * Returns the divisor used to standardize {@code column}. This is its recorded standard
     * deviation, or 1.0 when that value is zero or NaN, so such columns stay merely centered.
     */
    private double scaleFor(long column) {
        double s = this.std.getAsDouble(0L, column);
        return s > 0.0 ? s : 1.0;
    }

    /**
     * Solves the least-squares problem {@code x * P = y} via a pseudo-inverse.
     * It inverts {@code x} directly when there are fewer samples than features, and otherwise
     * uses the normal-equation form {@code pinv(x' x) x' y}.
     */
    private Matrix solveLeastSquares(Matrix x, Matrix y, int sampleCount, int featureCount) {
        boolean truncate = this.numberOfPrincipalComponents > 0
                && (long) this.numberOfPrincipalComponents
                        < Math.min(x.getRowCount(), x.getColumnCount() - 1L);
        if (sampleCount < featureCount) {
            Matrix inverse = truncate ? x.pinv(this.numberOfPrincipalComponents) : x.pinv();
            return inverse.mtimes(y);
        }
        Matrix xt = x.transpose();
        Matrix gram = xt.mtimes(x);
        Matrix gramInverse = truncate ? gram.pinv(this.numberOfPrincipalComponents) : gram.pinv();
        return gramInverse.mtimes(xt).mtimes(y);
    }

    /** Stores {@code parameters} as the coefficient matrix of the parameter variable. */
    public void setParameterMatrix(Matrix parameters) {
        this.getVariableMap().setMatrix(PARAMETERS, parameters);
    }

    /**
     * Clears the stored coefficients and the recorded standardization statistics.
     * After this, {@link #predictOne(Matrix)} reports an untrained model.
     */
    @Override
    public void reset() {
        this.getParameterVariable().clear();
        this.mean = null;
        this.std = null;
    }

    /** Returns an untrained regressor with the same configuration and labels. */
    @Override
    public Regressor emptyCopy() {
        LinearRegression lr = new LinearRegression(this.numberOfPrincipalComponents);
        lr.setInputLabel(this.getInputLabel());
        lr.setTargetLabel(this.getTargetLabel());
        return lr;
    }
}

