/*
 * Decompiled with CFR 0.152.
 */
package org.jdmp.core.algorithm.regression;

import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;
import org.ujmp.core.util.MathUtil;

/**
 * Multivariate linear regression trained with mini-batch gradient descent on the
 * mean squared error.
 *
 * <p>The model is a parameter matrix {@code P} of shape
 * {@code (featureCount + 1) x classCount}. Row 0 holds the bias terms, because a
 * column of ones is prepended to every input row. The prediction for an input row
 * {@code x} is {@code [1, x] * P}.
 *
 * <p>Each training step draws {@code batchSize} samples uniformly at random with
 * replacement. Training runs for at most
 * {@code ceil(dataSet.size() / batchSize * epochs)} steps. It stops earlier once an
 * exponential moving average of the step-to-step RMSE change falls below
 * {@code minImprovement}.
 */
public class LinearRegressionGradientDescent extends AbstractClassifier {
    private static final long serialVersionUID = 9107143531298310638L;
    public static final String PARAMETERS = "Parameters";
    public static final String GRADIENT = "Gradient";

    /** Sample label under which each training sample stores its features. */
    private static final String INPUT_LABEL = "Input";
    /** Sample label under which each training sample stores its expected output. */
    private static final String TARGET_LABEL = "Target";

    /** Number of samples drawn (with replacement) per gradient step. */
    private int batchSize = 1000;
    /** Learning rate. */
    private double eta = 0.01;
    /** Approximate number of passes over the data set. */
    private int epochs = 100;
    /** Early-stop threshold on the moving average of |RMSE change| per step. */
    private double minImprovement = 1.0E-6;
    /** Fraction by which all parameters (including the bias row) shrink after each step. */
    private double weightDecay = 0.0;
    /** When true, the training RMSE of every step is printed to stdout. */
    private boolean verbose = false;

    public LinearRegressionGradientDescent() {
        this.setParameterVariable(Variable.Factory.labeledVariable("Regression Parameters"));
        // The gradient gets its own variable; it must not replace the parameter variable.
        this.setGradientVariable(Variable.Factory.labeledVariable(GRADIENT));
    }

    public void setParameterVariable(Variable variable) {
        this.setVariable(PARAMETERS, variable);
    }

    public Variable getParameterVariable() {
        return (Variable) this.getVariableMap().get(PARAMETERS);
    }

    public void setGradientVariable(Variable variable) {
        this.setVariable(GRADIENT, variable);
    }

    public Variable getGradientVariable() {
        return (Variable) this.getVariableMap().get(GRADIENT);
    }

    public Matrix getParameterMatrix() {
        return (Matrix) this.getParameterVariable().getLast();
    }

    public Matrix getGradientMatrix() {
        return (Matrix) this.getGradientVariable().getLast();
    }

    /**
     * Predicts the output for a single sample.
     *
     * @param input the sample's features, converted to a single row via {@code toRowVector}
     * @return a row vector with one column per output ({@code 1 x classCount})
     */
    @Override
    public Matrix predictOne(Matrix input) {
        Matrix row = input.toRowVector(Calculation.Ret.NEW);
        Matrix bias = Matrix.Factory.ones(row.getRowCount(), 1L);
        Matrix data = Matrix.Factory.horCat(bias, row);
        // [1, x] * P; equivalent to (P^T * [1, x]^T)^T.
        return data.mtimes(this.getParameterMatrix());
    }

    /**
     * Fits the parameter matrix to {@code dataSet}.
     *
     * <p>The final parameters are stored via {@link #setParameterMatrix(Matrix)}. The
     * gradient of every step is stored via {@link #setGradientMatrix(Matrix)}.
     *
     * @param dataSet training samples, each providing "Input" and "Target" matrices
     * @throws IllegalArgumentException if {@code dataSet} is null or empty
     */
    @Override
    public void trainAll(ListDataSet dataSet) {
        if (dataSet == null || dataSet.size() == 0) {
            throw new IllegalArgumentException("dataSet must contain at least one sample");
        }
        int featureCount = this.getFeatureCount(dataSet);
        int classCount = this.getClassCount(dataSet);
        // Small random initialisation, scaled down by the number of inputs (incl. bias).
        Matrix parameters = ((DenseMatrix) Matrix.Factory.randn((long) (featureCount + 1), (long) classCount))
                .divide(featureCount + 1);

        // `epochs` passes over the data expressed as a number of mini-batch steps.
        long maxSteps = (long) Math.ceil((double) dataSet.size() / (double) this.batchSize * (double) this.epochs);

        FastArrayList<Matrix> inputs = new FastArrayList<Matrix>();
        FastArrayList<Matrix> targets = new FastArrayList<Matrix>();
        double lastRmse = 0.0;
        double averageImprovement = Double.NaN;

        for (long step = 0; step < maxSteps; ++step) {
            inputs.clear();
            targets.clear();
            for (int i = 0; i < this.batchSize; ++i) {
                Sample s = (Sample) dataSet.get(MathUtil.nextInteger(dataSet.size()));
                inputs.add(s.getAsMatrix(INPUT_LABEL).toRowVector(Calculation.Ret.NEW));
                targets.add(s.getAsMatrix(TARGET_LABEL).toRowVector(Calculation.Ret.NEW));
            }

            // One sample per row: x is batch x (featureCount + 1), target is batch x classCount.
            Matrix input = Matrix.Factory.vertCat(inputs);
            Matrix bias = Matrix.Factory.ones(input.getRowCount(), 1L);
            Matrix x = Matrix.Factory.horCat(bias, input);
            Matrix target = Matrix.Factory.vertCat(targets);

            Matrix diff = x.mtimes(parameters).minus(target);
            Matrix squared = diff.power(Calculation.Ret.NEW, 2.0);
            double rmse = Math.sqrt(squared.getValueSum() / (double) squared.getValueCount());

            // Only measure improvement once a previous RMSE exists; seeding the average
            // with (MAX_VALUE - rmse) would make early stopping practically unreachable.
            boolean converged = false;
            if (step > 0) {
                double improvement = Math.abs(lastRmse - rmse);
                if (MathUtil.isNaNOrInfinite(averageImprovement)) {
                    averageImprovement = improvement;
                } else {
                    averageImprovement = averageImprovement * 0.99 + 0.01 * improvement;
                }
                converged = averageImprovement < this.minImprovement;
            }
            lastRmse = rmse;
            if (this.verbose) {
                System.out.println(rmse);
            }

            // MSE gradient x^T * (xP - target); the constant factor 2 is folded into eta.
            Matrix gradient = x.transpose().mtimes(diff).times(this.eta / (double) this.batchSize);
            this.setGradientMatrix(gradient);
            parameters = parameters.minus(gradient).times(1.0 - this.weightDecay);

            if (converged) {
                break;
            }
        }
        this.setParameterMatrix(parameters);
    }

    public void setParameterMatrix(Matrix parameters) {
        this.getVariableMap().setMatrix(PARAMETERS, parameters);
    }

    public void setGradientMatrix(Matrix gradient) {
        this.getVariableMap().setMatrix(GRADIENT, gradient);
    }

    public int getBatchSize() {
        return this.batchSize;
    }

    /** @throws IllegalArgumentException if {@code batchSize <= 0} */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
        }
        this.batchSize = batchSize;
    }

    public double getEta() {
        return this.eta;
    }

    /** @throws IllegalArgumentException if {@code eta} is not a finite positive number */
    public void setEta(double eta) {
        if (!(eta > 0.0) || Double.isInfinite(eta)) {
            throw new IllegalArgumentException("eta must be finite and positive: " + eta);
        }
        this.eta = eta;
    }

    public int getEpochs() {
        return this.epochs;
    }

    /** @throws IllegalArgumentException if {@code epochs < 0} */
    public void setEpochs(int epochs) {
        if (epochs < 0) {
            throw new IllegalArgumentException("epochs must not be negative: " + epochs);
        }
        this.epochs = epochs;
    }

    public double getMinImprovement() {
        return this.minImprovement;
    }

    /** @throws IllegalArgumentException if {@code minImprovement} is negative or NaN */
    public void setMinImprovement(double minImprovement) {
        if (!(minImprovement >= 0.0)) {
            throw new IllegalArgumentException("minImprovement must not be negative: " + minImprovement);
        }
        this.minImprovement = minImprovement;
    }

    public double getWeightDecay() {
        return this.weightDecay;
    }

    /** @throws IllegalArgumentException unless {@code 0 <= weightDecay < 1} */
    public void setWeightDecay(double weightDecay) {
        if (!(weightDecay >= 0.0 && weightDecay < 1.0)) {
            throw new IllegalArgumentException("weightDecay must be in [0, 1): " + weightDecay);
        }
        this.weightDecay = weightDecay;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Discards the learned parameters and the last stored gradient. */
    @Override
    public void reset() {
        this.getParameterVariable().clear();
        Variable gradient = this.getGradientVariable();
        if (gradient != null) {
            gradient.clear();
        }
    }

    /** Returns an untrained instance with the same hyperparameters as this one. */
    @Override
    public Classifier emptyCopy() {
        LinearRegressionGradientDescent copy = new LinearRegressionGradientDescent();
        copy.batchSize = this.batchSize;
        copy.eta = this.eta;
        copy.epochs = this.epochs;
        copy.minImprovement = this.minImprovement;
        copy.weightDecay = this.weightDecay;
        copy.verbose = this.verbose;
        return copy;
    }
}

