package org.jdmp.core.algorithm.regression;

import org.jdmp.core.algorithm.classification.AbstractClassifier;
import org.jdmp.core.algorithm.classification.Classifier;
import org.jdmp.core.dataset.ListDataSet;
import org.jdmp.core.sample.Sample;
import org.jdmp.core.variable.Variable;
import org.ujmp.core.Matrix;
import org.ujmp.core.calculation.Calculation;
import org.ujmp.core.collections.list.FastArrayList;
import org.ujmp.core.util.MathUtil;

/* loaded from: input_file:org/jdmp/core/algorithm/regression/LinearRegressionGradientDescent.class */
/**
 * Multi-output linear regression trained with mini-batch gradient descent on the squared error.
 *
 * <p>The model keeps a parameter matrix of shape {@code (featureCount + 1) x classCount}; the extra
 * first row is the bias term, matched by a column of ones prepended to every input during training
 * and prediction. The most recent gradient step is published in the {@link #GRADIENT} variable.
 */
public class LinearRegressionGradientDescent extends AbstractClassifier {
    private static final long serialVersionUID = 9107143531298310638L;
    public static final String PARAMETERS = "Parameters";
    public static final String GRADIENT = "Gradient";
    private int batchSize = 1000;
    private double eta = 0.01d;
    private int epochs = 100;
    private double minImprovement = 1.0E-6d;
    private double weightDecay = 0.0d;

    /** Creates an untrained model with empty parameter and gradient variables. */
    public LinearRegressionGradientDescent() {
        setParameterVariable(Variable.Factory.labeledVariable("Regression Parameters"));
        // Previously this called setParameterVariable again, overwriting the parameters and never
        // registering the gradient variable.
        setGradientVariable(Variable.Factory.labeledVariable(GRADIENT));
    }

    /** Stores {@code variable} under the {@link #PARAMETERS} key. */
    public void setParameterVariable(Variable variable) {
        setVariable(PARAMETERS, variable);
    }

    /** Returns the variable registered under {@link #PARAMETERS}, or whatever the map holds there. */
    public Variable getParameterVariable() {
        return (Variable) getVariableMap().get(PARAMETERS);
    }

    /** Stores {@code variable} under the {@link #GRADIENT} key. */
    public void setGradientVariable(Variable variable) {
        setVariable(GRADIENT, variable);
    }

    /** Returns the variable registered under {@link #GRADIENT}. */
    public Variable getGradientVariable() {
        return (Variable) getVariableMap().get(GRADIENT);
    }

    /** Returns the latest value of the parameter variable. */
    public Matrix getParameterMatrix() {
        return (Matrix) getParameterVariable().getLast();
    }

    /** Returns the latest value of the gradient variable. */
    public Matrix getGradientMatrix() {
        return (Matrix) getGradientVariable().getLast();
    }

    /**
     * Predicts the outputs for a single input.
     *
     * <p>The input is flattened via {@code toColumnVector}, a leading 1 is added for the bias, and the
     * result is {@code (P^T * x)^T} where {@code P} is the parameter matrix.
     * NOTE(review): the shape arithmetic only works if UJMP's {@code toColumnVector} yields a
     * {@code 1 x n} matrix here (so the result is presumably {@code 1 x classCount}) — confirm.
     *
     * @param matrix the input features of one sample
     * @return the predicted outputs
     */
    @Override
    public Matrix predictOne(Matrix matrix) {
        Matrix columnVector = matrix.toColumnVector(Calculation.Ret.NEW);
        return getParameterMatrix().transpose().mtimes(Matrix.Factory.horCat(new Matrix[]{Matrix.Factory.ones(columnVector.getRowCount(), 1L), columnVector}).transpose()).transpose();
    }

    /**
     * Trains the model on {@code listDataSet} using random mini-batches.
     *
     * <p>Each iteration draws {@code batchSize} samples with replacement, computes the residuals of
     * the current model, and takes a gradient step of size {@code eta / batchSize}. The weight-decay
     * factor {@code 1 - weightDecay} is applied after each step. Training runs for {@code epochs}
     * epochs, where one epoch is {@code ceil(size / batchSize)} batches. It stops early once an
     * exponential moving average of the batch-to-batch RMSE change drops below
     * {@code minImprovement}. The final parameters are stored via {@link #setParameterMatrix}.
     *
     * @param listDataSet training samples; each is read through its {@code "Input"} and
     *     {@code "Target"} matrices
     */
    @Override
    public void trainAll(ListDataSet listDataSet) {
        final int sampleCount = listDataSet.size();
        final int featureCount = getFeatureCount(listDataSet);
        Matrix parameters = Matrix.Factory.randn(featureCount + 1, getClassCount(listDataSet)).divide(featureCount + 1);

        // Ceiling division: a data set smaller than one batch still gets one batch per epoch
        // (floor division used to yield zero iterations, leaving the random initial parameters).
        final long batchesPerEpoch = (sampleCount + (long) this.batchSize - 1) / this.batchSize;
        final long iterations = batchesPerEpoch * this.epochs;

        FastArrayList inputs = new FastArrayList();
        FastArrayList targets = new FastArrayList();
        // NaN means "no previous batch yet"; seeding with Double.MAX_VALUE made the first improvement
        // ~1e308, so the moving average took tens of thousands of batches to reach minImprovement.
        double previousError = Double.NaN;
        double smoothedImprovement = Double.NaN;

        for (long iteration = 0; iteration < iterations; iteration++) {
            inputs.clear();
            targets.clear();
            for (int b = 0; b < this.batchSize; b++) {
                Sample sample = (Sample) listDataSet.get(MathUtil.nextInteger(sampleCount));
                inputs.add(sample.getAsMatrix("Input").toColumnVector(Calculation.Ret.NEW));
                targets.add(sample.getAsMatrix("Target").toColumnVector(Calculation.Ret.NEW));
            }
            Matrix inputBatch = Matrix.Factory.vertCat(inputs);
            // Prepend a column of ones for the bias, then transpose so samples become columns.
            Matrix design = Matrix.Factory.horCat(new Matrix[]{Matrix.Factory.ones(inputBatch.getRowCount(), 1L), inputBatch}).transpose();
            Matrix error = parameters.transpose().mtimes(design).minus(Matrix.Factory.vertCat(targets).transpose());
            double rmse = Math.sqrt(error.power(Calculation.Ret.NEW, 2.0d).getValueSum() / error.getValueCount());

            Matrix step = design.mtimes(error.transpose()).times(this.eta / this.batchSize);
            setGradientMatrix(step);
            parameters = parameters.minus(step).times(1.0d - this.weightDecay);

            if (!Double.isNaN(previousError)) {
                double improvement = Math.abs(previousError - rmse);
                smoothedImprovement = MathUtil.isNaNOrInfinite(smoothedImprovement)
                        ? improvement
                        : (smoothedImprovement * 0.99d) + (0.01d * improvement);
                if (smoothedImprovement < this.minImprovement) {
                    break;
                }
            }
            previousError = rmse;
        }
        setParameterMatrix(parameters);
    }

    /** Stores {@code matrix} as the value of the {@link #PARAMETERS} entry. */
    public void setParameterMatrix(Matrix matrix) {
        getVariableMap().setMatrix(PARAMETERS, matrix);
    }

    /** Stores {@code matrix} as the value of the {@link #GRADIENT} entry. */
    public void setGradientMatrix(Matrix matrix) {
        getVariableMap().setMatrix(GRADIENT, matrix);
    }

    /** Discards the learned parameters by clearing the parameter variable. */
    @Override
    public void reset() {
        getParameterVariable().clear();
    }

    /** Returns a fresh, untrained instance. */
    @Override
    public Classifier emptyCopy() {
        return new LinearRegressionGradientDescent();
    }
}
