package jsat.regression;

import java.util.Collections;
import java.util.Iterator;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.NoDecay;
import jsat.parameters.Parameterized;
import jsat.utils.IntList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/regression/StochasticRidgeRegression.class */
/**
 * Linear regression with an L2 (ridge) penalty, trained by mini-batch stochastic gradient
 * descent on the squared error.
 *
 * <p>The model predicts {@code w . x + bias}. On each step the weights are shrunk by
 * {@code (1 - rate * lambda)} and then moved against the residual-weighted inputs of the
 * mini-batch. The bias is updated but never regularized.
 *
 * <p>When more than a quarter of the training inputs are sparse, shrinkage is applied lazily.
 * Each weight is rescaled only when it is next touched (or at a flush), so the whole weight
 * vector is not rescaled on every step.
 */
public class StochasticRidgeRegression implements Regressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -3462783438115627128L;

    /** L2 regularization strength; always positive. */
    private double lambda;
    /** Number of passes over the shuffled training data; always positive. */
    private int epochs;
    /** Requested mini-batch size; always positive (capped at the data set size in training). */
    private int batchSize;
    /** Base learning rate handed to {@link #learningDecay}. */
    private double learningRate;
    /** Schedule producing the per-epoch learning rate. */
    private DecayRate learningDecay;
    /** Learned weight vector; {@code null} until {@link #train(RegressionDataSet)} is called. */
    private Vec w;
    /** Learned (unregularized) intercept. */
    private double bias;

    /**
     * Creates a learner that uses a {@link NoDecay} learning-rate schedule.
     *
     * @param lambda       the regularization strength; must be positive and finite
     * @param epochs       number of passes over the data; must be positive
     * @param batchSize    mini-batch size; must be positive
     * @param learningRate the base learning rate
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public StochasticRidgeRegression(double lambda, int epochs, int batchSize, double learningRate) {
        this(lambda, epochs, batchSize, learningRate, new NoDecay());
    }

    /**
     * Creates a learner with an explicit learning-rate schedule.
     *
     * @param lambda        the regularization strength; must be positive and finite
     * @param epochs        number of passes over the data; must be positive
     * @param batchSize     mini-batch size; must be positive
     * @param learningRate  the base learning rate
     * @param learningDecay the schedule used to derive each epoch's learning rate
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public StochasticRidgeRegression(double lambda, int epochs, int batchSize, double learningRate, DecayRate learningDecay) {
        setLambda(lambda);
        setEpochs(epochs);
        setBatchSize(batchSize);
        setLearningRate(learningRate);
        setLearningDecay(learningDecay);
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @param lambda a positive, finite value
     * @throws IllegalArgumentException if {@code lambda} is NaN, infinite, or not positive
     */
    public void setLambda(double lambda) {
        if (Double.isNaN(lambda) || Double.isInfinite(lambda) || lambda <= 0.0d) {
            throw new IllegalArgumentException("lambda must be a positive constant, not " + lambda);
        }
        this.lambda = lambda;
    }

    /** @return the L2 regularization strength */
    public double getLambda() {
        return this.lambda;
    }

    /**
     * Sets the base learning rate. No validation is performed here.
     *
     * @param learningRate the base learning rate passed to the decay schedule
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the base learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning-rate schedule.
     *
     * @param learningDecay the schedule queried once per epoch during training
     */
    public void setLearningDecay(DecayRate learningDecay) {
        this.learningDecay = learningDecay;
    }

    /** @return the learning-rate schedule */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the mini-batch size. During training it is capped at the number of samples.
     *
     * @param batchSize a positive batch size
     * @throws IllegalArgumentException if {@code batchSize <= 0}
     */
    public void setBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("Batch size must be a positive constant, not " + batchSize);
        }
        this.batchSize = batchSize;
    }

    /** @return the requested mini-batch size */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * Sets the number of training epochs.
     *
     * @param epochs a positive number of passes over the data
     * @throws IllegalArgumentException if {@code epochs <= 0}
     */
    public void setEpochs(int epochs) {
        if (epochs <= 0) {
            throw new IllegalArgumentException("At least one epoch must be performed, can not use " + epochs);
        }
        this.epochs = epochs;
    }

    /** @return the number of training epochs */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Returns the internal weight vector (not a copy), or {@code null} before training.
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /** @return the learned intercept */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.bias;
    }

    /**
     * Returns the single weight vector of this model.
     *
     * @param index must be {@code 0}
     * @throws IndexOutOfBoundsException for any other index (including negative ones)
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getRawWeight();
    }

    /**
     * Returns the single bias term of this model.
     *
     * @param index must be {@code 0}
     * @throws IndexOutOfBoundsException for any other index (including negative ones)
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        if (index != 0) {
            throw new IndexOutOfBoundsException("Model has only 1 weight vector");
        }
        return getBias();
    }

    /** @return always {@code 1} */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Predicts {@code w . x + bias} for the point's numerical values. Because {@code w} is
     * {@code null} until training, calling this on an untrained model throws
     * {@link NullPointerException}.
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return regress(dataPoint.getNumericalValues());
    }

    /** Linear prediction for a raw input vector. */
    private double regress(Vec x) {
        return this.w.dot(x) + this.bias;
    }

    /**
     * Trains on the data set. The boolean argument is ignored: this implementation always trains
     * on the calling thread by delegating to {@link #train(RegressionDataSet)}.
     *
     * @param dataSet  the training data
     * @param parallel ignored
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Fits the model from scratch with mini-batch SGD. Both the weights and the bias are reset.
     *
     * <p>Each epoch shuffles the sample order with {@link Collections#shuffle(java.util.List)},
     * so results are not deterministic. It then processes every full mini-batch; a trailing
     * partial batch is skipped. Residuals for a batch are computed with the weights from before
     * that batch's update.
     *
     * @param dataSet the training data
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet) {
        final int n = dataSet.getSampleSize();
        this.w = new DenseVector(dataSet.getNumNumericalVars());
        this.bias = 0.0d; // retraining must not inherit the previous intercept
        if (n == 0) {
            return; // nothing to fit; leave the all-zero model
        }
        final int batch = Math.min(this.batchSize, n);

        final IntList order = new IntList(n);
        ListUtils.addRange(order, 0, n, 1);

        int sparseCount = 0;
        for (int row = 0; row < n; row++) {
            if (dataSet.getDataPoint(row).getNumericalValues().isSparse()) {
                sparseCount++;
            }
        }
        final boolean lazy = sparseCount > n / 4;
        // lastUpdate[j] = the step through which weight j has received its L2 shrinkage
        final int[] lastUpdate = lazy ? new int[this.w.length()] : null;

        final double[] errors = new double[batch];
        int time = 0;
        for (int epoch = 0; epoch < this.epochs; epoch++) {
            Collections.shuffle(order);
            // Dividing by the batch size averages the gradient over the mini-batch.
            final double rate = this.learningDecay.rate(epoch, this.epochs, this.learningRate) / batch;
            final double keep = 1.0d - rate * this.lambda;

            for (int start = 0; start + batch <= n; start += batch) {
                final int end = start + batch;
                time++;

                if (lazy) {
                    // Predictions must use w as of the previous step, so apply any shrinkage
                    // still pending on the weights this batch will read.
                    for (int b = start; b < end; b++) {
                        final Iterator<IndexValue> it =
                                dataSet.getDataPoint(order.get(b)).getNumericalValues().iterator();
                        while (it.hasNext()) {
                            applyPendingShrinkage(it.next().getIndex(), lastUpdate, time - 1, keep);
                        }
                    }
                }

                // All residuals are computed before any weight in this batch is changed.
                for (int b = start; b < end; b++) {
                    final int row = order.get(b);
                    errors[b - start] = regress(dataSet.getDataPoint(row)) - dataSet.getTargetValue(row);
                }

                if (!lazy) {
                    this.w.mutableMultiply(keep); // one shrinkage per step
                }
                for (int b = start; b < end; b++) {
                    final double step = rate * errors[b - start];
                    this.bias -= step; // the intercept is not regularized
                    final Vec x = dataSet.getDataPoint(order.get(b)).getNumericalValues();
                    if (lazy) {
                        final Iterator<IndexValue> it = x.iterator();
                        while (it.hasNext()) {
                            final IndexValue entry = it.next();
                            final int j = entry.getIndex();
                            applyPendingShrinkage(j, lastUpdate, time, keep);
                            this.w.increment(j, -step * entry.getValue());
                        }
                    } else {
                        this.w.mutableSubtract(step, x);
                    }
                }
            }

            // The shrink factor is only valid for this epoch when a real decay schedule is in
            // use, so settle everything now. Under NoDecay the factor is assumed constant, so
            // a single flush after the last epoch suffices.
            if (lazy && (!(this.learningDecay instanceof NoDecay) || epoch == this.epochs - 1)) {
                for (int j = 0; j < this.w.length(); j++) {
                    applyPendingShrinkage(j, lastUpdate, time, keep);
                }
            }
        }
    }

    /**
     * Brings weight {@code index} up to date with all shrinkage steps through {@code upTo},
     * multiplying it by {@code keep} once per missed step, and records the new step.
     */
    private void applyPendingShrinkage(int index, int[] lastUpdate, int upTo, double keep) {
        final int pending = upTo - lastUpdate[index];
        if (pending > 0) {
            this.w.set(index, this.w.get(index) * Math.pow(keep, pending));
            lastUpdate[index] = upTo;
        }
    }

    /** @return {@code false}; per-sample weights are not used by {@link #train(RegressionDataSet)} */
    @Override // jsat.regression.Regressor
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Returns a copy with the same hyper-parameters, a copied weight vector (if trained) and the
     * same bias. The learning-rate schedule object is shared, not copied.
     */
    @Override // jsat.regression.Regressor
    public StochasticRidgeRegression clone() {
        StochasticRidgeRegression copy = new StochasticRidgeRegression(this.lambda, this.epochs, this.batchSize, this.learningRate, this.learningDecay);
        if (this.w != null) {
            copy.w = this.w.mo46clone();
        }
        copy.bias = this.bias;
        return copy;
    }
}
