/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * Gradient updater that keeps a running "velocity" of past scaled gradients
 * and applies it, weighted by a momentum coefficient, on every step.
 * <p>
 * Two update rules are implemented, selected by the {@code nestrov} flag
 * (with {@code m} the momentum, {@code v} the stored velocity, {@code g} the
 * gradient and {@code eta} the step size):
 * <ul>
 * <li>flag set: {@code x += m*m*v - (1+m)*eta*g}</li>
 * <li>flag clear: {@code x += m*v - eta*g}</li>
 * </ul>
 * In both cases the velocity is afterwards replaced by {@code m*v - eta*g}.
 * The same arithmetic is applied to a scalar bias term.
 * <p>
 * {@link #setup(int)} allocates the velocity and therefore has to be called
 * before the first {@code update}; until then the velocity is {@code null}.
 */
public class SGDMomentum
implements GradientUpdater {
    private static final long serialVersionUID = -3837883539010356899L;
    /** Momentum coefficient, validated by {@link #setMomentum(double)} to lie in (0, 1). */
    private double momentum;
    /** Selects which of the two update rules {@code update} applies. */
    private boolean nestrov;
    /** Running velocity for the weight vector; {@code null} until {@link #setup(int)} is called. */
    private Vec velocity;
    /** Running velocity for the scalar bias term. */
    private double biasVelocity;

    /**
     * Creates a new momentum updater.
     *
     * @param momentum the momentum coefficient, must be in (0, 1)
     * @param nestrov {@code true} to use the {@code m*m*v - (1+m)*eta*g} rule,
     * {@code false} to use the {@code m*v - eta*g} rule
     * @throws IllegalArgumentException if {@code momentum} is not in (0, 1)
     */
    public SGDMomentum(double momentum, boolean nestrov) {
        this.setMomentum(momentum);
        this.nestrov = nestrov;
    }

    /**
     * Creates a new momentum updater with the {@code nestrov} flag set to
     * {@code true}.
     *
     * @param momentum the momentum coefficient, must be in (0, 1)
     * @throws IllegalArgumentException if {@code momentum} is not in (0, 1)
     */
    public SGDMomentum(double momentum) {
        this(momentum, true);
    }

    /**
     * Copy constructor. The velocity vector, when present, is cloned so the
     * new object does not share it with {@code toCopy}.
     *
     * @param toCopy the updater to copy
     */
    public SGDMomentum(SGDMomentum toCopy) {
        this.momentum = toCopy.momentum;
        // The update rule selector must travel with the copy; without this a
        // clone would always fall back to the field default (false).
        this.nestrov = toCopy.nestrov;
        if (toCopy.velocity != null) {
            this.velocity = toCopy.velocity.clone();
        }
        this.biasVelocity = toCopy.biasVelocity;
    }

    /**
     * Sets the momentum coefficient.
     *
     * @param momentum the new coefficient, must be strictly between 0 and 1
     * @throws IllegalArgumentException if {@code momentum} is NaN, is less
     * than or equal to 0, or is greater than or equal to 1
     */
    public void setMomentum(double momentum) {
        if (momentum <= 0.0 || momentum >= 1.0 || Double.isNaN(momentum)) {
            throw new IllegalArgumentException("Momentum must be in (0,1) not " + momentum);
        }
        this.momentum = momentum;
    }

    /**
     * Returns the momentum coefficient currently in use.
     *
     * @return the momentum coefficient
     */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Updates {@code x} in place without a bias term. Delegates to
     * {@link #update(Vec, Vec, double, double, double)} with a bias and bias
     * gradient of zero and discards the returned bias update.
     *
     * @param x the vector to alter in place
     * @param grad the gradient to apply
     * @param eta the step size
     */
    @Override
    public void update(Vec x, Vec grad, double eta) {
        this.update(x, grad, eta, 0.0, 0.0);
    }

    /**
     * Updates {@code x} in place using the stored velocity and the given
     * gradient, then advances the stored vector and bias velocities.
     *
     * @param x the vector to alter in place
     * @param grad the gradient to apply
     * @param eta the step size
     * @param bias the current bias value; not read by this implementation
     * @param biasGrad the gradient of the bias term
     * @return the bias update, computed with the opposite sign to the change
     * added to {@code x} (presumably the caller subtracts it from the bias;
     * confirm against the {@code GradientUpdater} contract)
     */
    @Override
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        double biasUpdate;
        if (this.nestrov) {
            // x += m^2 * v - (1 + m) * eta * g
            x.mutableAdd(this.momentum * this.momentum, this.velocity);
            x.mutableSubtract((1.0 + this.momentum) * eta, grad);
            biasUpdate = -this.momentum * this.momentum * this.biasVelocity + (1.0 + this.momentum) * eta * biasGrad;
        } else {
            // x += m * v - eta * g
            x.mutableAdd(this.momentum, this.velocity);
            x.mutableSubtract(eta, grad);
            biasUpdate = -this.momentum * this.biasVelocity + eta * biasGrad;
        }
        // Velocities are advanced only after x was updated, because both
        // branches above are written in terms of the previous velocity.
        this.velocity.mutableMultiply(this.momentum);
        this.velocity.mutableSubtract(eta, grad);
        this.biasVelocity = this.biasVelocity * this.momentum - eta * biasGrad;
        return biasUpdate;
    }

    /**
     * Returns a copy of this updater created through the copy constructor.
     *
     * @return a new updater with the same momentum, update rule and velocities
     */
    @Override
    public SGDMomentum clone() {
        return new SGDMomentum(this);
    }

    /**
     * Allocates a fresh velocity vector of dimension {@code d} and resets the
     * bias velocity to zero, discarding any previously accumulated state.
     *
     * @param d the dimension passed to the new {@code DenseVector}
     */
    @Override
    public void setup(int d) {
        // NOTE(review): the DenseVector is wrapped in a ScaledVector, presumably
        // so the per-step multiply by the momentum stays cheap - confirm.
        this.velocity = new ScaledVector(new DenseVector(d));
        this.biasVelocity = 0.0;
    }
}

