/*
 * Decompiled with CFR 0.152.
 */
package jsat.math.optimization.stochastic;

import jsat.linear.ConstantVector;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.optimization.stochastic.GradientUpdater;

/**
 * AdaGrad gradient updater: scales each coordinate's step by the inverse square root of that
 * coordinate's accumulated squared gradients.
 *
 * <p>{@link #setup(int)} must be called before any {@code update} call. It seeds every
 * accumulator (one per weight, plus one for the bias) with {@code 1.0}. The step for a
 * coordinate divides by the square root of the accumulator <em>before</em> the current squared
 * gradient is added. The very first step on a coordinate is therefore {@code eta * grad_i}.
 */
public class AdaGrad
implements GradientUpdater {
    private static final long serialVersionUID = 5138474612999751777L;
    /**
     * Per-coordinate running sum of squared gradients, seeded with 1.0 by {@link #setup(int)}.
     * The (misspelled) name is kept unchanged because field names are part of the default
     * serialized form of this class.
     */
    private Vec daigG;
    /** Running sum of squared bias gradients, seeded with 1.0 by {@link #setup(int)}. */
    private double biasG;

    /**
     * Creates an updater with no state; {@link #setup(int)} must be called before use.
     */
    public AdaGrad() {
    }

    /**
     * Copy constructor.
     *
     * @param toCopy the updater whose accumulated state is copied. Its accumulator vector is
     *               copied via {@code clone()} when it has been set up.
     */
    public AdaGrad(AdaGrad toCopy) {
        if (toCopy.daigG != null) {
            this.daigG = toCopy.daigG.clone();
        }
        this.biasG = toCopy.biasG;
    }

    /**
     * Applies one AdaGrad step to {@code x} in place, ignoring any bias term.
     *
     * @param x    the weight vector to update in place
     * @param grad the gradient for {@code x}
     * @param eta  the base learning rate
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public void update(Vec x, Vec grad, double eta) {
        this.update(x, grad, eta, 0.0, 0.0);
    }

    /**
     * Applies one AdaGrad step to {@code x} in place and computes the bias step.
     *
     * <p>Only the entries that {@code grad} iterates over are touched. Each weight moves by
     * {@code -eta * g / sqrt(G)}, where {@code G} is the accumulator value before {@code g*g}
     * is added.
     *
     * @param x        the weight vector to update in place
     * @param grad     the gradient for {@code x}
     * @param eta      the base learning rate
     * @param bias     the current bias value (not used by this updater)
     * @param biasGrad the gradient for the bias term
     * @return {@code eta * biasGrad / sqrt(biasAccumulator)}, returned with a positive sign.
     *         Presumably the caller subtracts it from the bias; confirm against the
     *         {@code GradientUpdater} contract.
     * @throws IllegalStateException if {@link #setup(int)} has not been called
     */
    @Override
    public double update(Vec x, Vec grad, double eta, double bias, double biasGrad) {
        // Without setup the accumulators are null/0.0. An empty sparse gradient would skip the
        // loop and divide by sqrt(0.0) below, silently returning Infinity/NaN. Fail loudly instead.
        if (this.daigG == null) {
            throw new IllegalStateException("setup(int) must be called before update");
        }
        for (IndexValue iv : grad) {
            int indx = iv.getIndex();
            double grad_i = iv.getValue();
            double g_ii = this.daigG.get(indx);
            // Divide by the accumulator *before* adding grad_i^2. The 1.0 seed from setup keeps
            // the denominator positive on the first step.
            x.increment(indx, -eta * grad_i / Math.sqrt(g_ii));
            this.daigG.increment(indx, grad_i * grad_i);
        }
        double biasUpdate = eta * biasGrad / Math.sqrt(this.biasG);
        this.biasG += biasGrad * biasGrad;
        return biasUpdate;
    }

    /**
     * Returns a copy of this updater, including its accumulated state.
     *
     * @return a new {@code AdaGrad} built with the copy constructor
     */
    @Override
    public AdaGrad clone() {
        return new AdaGrad(this);
    }

    /**
     * (Re)initializes the updater for a weight vector of dimension {@code d}, seeding every
     * per-coordinate accumulator and the bias accumulator with {@code 1.0}. Any previously
     * accumulated state is discarded.
     *
     * @param d the dimension of the weight vector that will be updated
     */
    @Override
    public void setup(int d) {
        this.daigG = new DenseVector(new ConstantVector(1.0, d));
        this.biasG = 1.0;
    }
}

