/*
 * Decompiled with CFR 0.152.
 */
package jsat.regression;

import java.util.Arrays;
import jsat.classifiers.DataPoint;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.concurrent.ParallelUtils;

/**
 * Linear least-squares regression with an L2 (ridge) penalty of strength {@code lambda}.
 * <p>
 * Training prepends a column of ones to the numeric features so that coefficient 0 acts
 * as the intercept. After solving, that coefficient is stored separately as the bias and
 * the remaining coefficients become the weight vector used by {@link #regress(DataPoint)}.
 * <p>
 * NOTE: the intercept column is part of the penalized system in both solver modes, so the
 * bias is shrunk toward zero along with the other coefficients.
 */
public class RidgeRegression implements Regressor, Parameterized {
    private static final long serialVersionUID = -4605757038780391895L;
    /** Regularization strength; {@link #setLambda(double)} keeps it finite and strictly positive. */
    private double lambda;
    /** Learned feature coefficients (intercept excluded); {@code null} until {@link #train} runs. */
    private Vec w;
    /** Learned intercept term. */
    private double bias;
    /** Solver used by {@link #train}; any value other than {@code EXACT_SVD} (including null) uses Cholesky. */
    private SolverMode mode;

    /** Creates a ridge regressor with {@code lambda = 0.01} and the Cholesky solver. */
    public RidgeRegression() {
        this(0.01);
    }

    /**
     * Creates a ridge regressor that uses the Cholesky solver.
     *
     * @param regularization the penalty strength; must be finite and &gt; 0
     * @throws IllegalArgumentException if {@code regularization} is invalid
     */
    public RidgeRegression(double regularization) {
        this(regularization, SolverMode.EXACT_CHOLESKY);
    }

    /**
     * Creates a ridge regressor with an explicit solver.
     *
     * @param regularization the penalty strength; must be finite and &gt; 0
     * @param mode the solver to use during training
     * @throws IllegalArgumentException if {@code regularization} is invalid
     */
    public RidgeRegression(double regularization, SolverMode mode) {
        this.setLambda(regularization);
        this.setSolverMode(mode);
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @param lambda the penalty strength; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code lambda} is NaN, infinite, or &lt;= 0
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0.0) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the L2 regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Selects the solver used by subsequent calls to {@link #train}.
     *
     * @param mode the solver; a value other than {@code EXACT_SVD} selects Cholesky
     */
    public void setSolverMode(SolverMode mode) {
        this.mode = mode;
    }

    /** @return the solver used by {@link #train} */
    public SolverMode getSolverMode() {
        return this.mode;
    }

    /**
     * Predicts {@code w . x + bias} for the numeric features of {@code data}.
     *
     * @param data the point to predict
     * @return the predicted target value
     * @throws NullPointerException if the model has not been trained yet
     */
    @Override
    public double regress(DataPoint data) {
        Vec x = data.getNumericalValues();
        return this.w.dot(x) + this.bias;
    }

    /**
     * Fits the model to {@code dataSet}, replacing any previous weights and bias.
     *
     * @param dataSet the training data; only numeric features are used and weights are ignored
     * @param parallel whether the Cholesky path may use {@code ParallelUtils.CACHED_THREAD_POOL}
     */
    @Override
    public void train(RegressionDataSet dataSet, boolean parallel) {
        final int dim = dataSet.getNumNumericalVars() + 1;
        final int n = dataSet.getSampleSize();
        // Design matrix with a leading column of ones so coefficient 0 becomes the intercept.
        DenseMatrix X = new DenseMatrix(n, dim);
        for (int i = 0; i < n; ++i) {
            Vec from = dataSet.getDataPoint(i).getNumericalValues();
            X.set(i, 0, 1.0);
            for (int j = 0; j < from.length(); ++j) {
                X.set(i, j + 1, from.get(j));
            }
        }
        Vec Y = dataSet.getTargetValues();

        Vec solution = this.mode == SolverMode.EXACT_SVD
                ? solveSvd(X, Y, dim)
                : solveCholesky(X, Y, parallel);

        // Split the stacked solution into intercept (index 0) and feature coefficients.
        DenseVector coefficients = new DenseVector(solution.length() - 1);
        for (int i = 0; i < coefficients.length(); ++i) {
            coefficients.set(i, solution.get(i + 1));
        }
        this.bias = solution.get(0);
        this.w = coefficients;
    }

    /**
     * Solves the ridge problem via an SVD of the design matrix.
     * The intent is the closed form {@code V * diag(s / (s^2 + lambda)) * U^T * Y}.
     * NOTE(review): this assumes {@code Matrix.diagMult(V, d)} scales the columns of V
     * in place; confirm against the jsat.linear API.
     */
    private Vec solveSvd(DenseMatrix X, Vec Y, int dim) {
        SingularValueDecomposition svd = new SingularValueDecomposition(X);
        // Arrays.copyOf zero-pads when fewer than dim singular values are returned.
        double[] ridgeD = Arrays.copyOf(svd.getSingularValues(), dim);
        for (int i = 0; i < ridgeD.length; ++i) {
            ridgeD[i] = 1.0 / (Math.pow(ridgeD[i], 2.0) + this.lambda);
        }
        Matrix U = svd.getU();
        Matrix V = svd.getV();
        Matrix.diagMult(V, DenseVector.toDenseVec(ridgeD));
        Matrix.diagMult(V, DenseVector.toDenseVec(svd.getSingularValues()));
        return V.multiply(U.transpose()).multiply(Y);
    }

    /**
     * Solves {@code (X^T X + lambda I) w = X^T Y} with a Cholesky factorization.
     * The system is solved directly rather than by forming the explicit inverse, which
     * is cheaper and numerically more accurate.
     */
    private Vec solveCholesky(DenseMatrix X, Vec Y, boolean parallel) {
        Matrix H = parallel
                ? X.transposeMultiply((Matrix) X, ParallelUtils.CACHED_THREAD_POOL)
                : X.transposeMultiply(X);
        // Add lambda to every diagonal entry, including the intercept's (index 0).
        for (int i = 0; i < H.rows(); ++i) {
            H.increment(i, i, this.lambda);
        }
        CholeskyDecomposition cd = parallel
                ? new CholeskyDecomposition(H, ParallelUtils.CACHED_THREAD_POOL)
                : new CholeskyDecomposition(H);
        return cd.solve(X.transpose().multiply(Y));
    }

    /** @return {@code false}; sample weights are ignored during training */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a deep copy of this regressor, preserving lambda, solver mode, weights and bias.
     */
    @Override
    public RidgeRegression clone() {
        // Pass the mode explicitly: the single-argument constructor would reset it to EXACT_CHOLESKY.
        RidgeRegression clone = new RidgeRegression(this.lambda, this.mode);
        if (this.w != null) {
            clone.w = this.w.clone();
        }
        clone.bias = this.bias;
        return clone;
    }

    /** Available strategies for solving the ridge system in {@link RidgeRegression#train}. */
    public enum SolverMode {
        /** Solve the regularized normal equations via Cholesky factorization. */
        EXACT_CHOLESKY,
        /** Solve via a singular value decomposition of the design matrix. */
        EXACT_SVD;
    }
}

