package recunn.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.matrix.Matrix;

/* loaded from: input_file:recunn/model/GruLayer.class */
/**
 * A gated recurrent unit (GRU) layer.
 *
 * <p>Each call to {@link #forward} consumes one input, updates the layer's hidden state
 * ({@link #context}) and returns the new hidden state. Written in terms of the {@link Graph}
 * operations used below, one step is:
 *
 * <pre>
 *   z  = fMix(IHmix * x + HHmix * h + Bmix)              // update ("mix") gate
 *   r  = fReset(IHreset * x + HHreset * h + Breset)      // reset gate
 *   c  = fNew(IHnew * x + HHnew * (r .* h) + Bnew)       // candidate state
 *   h' = z .* h + (1 - z) .* c
 * </pre>
 *
 * <p>Here {@code *} is {@code graph.mul}, {@code .*} is {@code graph.elmul}, {@code +} is
 * {@code graph.add} and {@code 1 - z} is {@code graph.oneMinus}. By default {@code fMix} and
 * {@code fReset} are {@link SigmoidUnit}s and {@code fNew} is a {@link TanhUnit}.
 *
 * <p>Instances are stateful: {@code forward} replaces {@link #context}. No synchronization is
 * performed.
 */
public class GruLayer implements Model {
    private static final long serialVersionUID = 1;

    int inputDimension;
    int outputDimension;

    // Update ("mix") gate parameters: input->hidden weights, hidden->hidden weights, bias.
    Matrix IHmix;
    Matrix HHmix;
    Matrix Bmix;
    // Candidate ("new") state parameters: input->hidden weights, hidden->hidden weights, bias.
    Matrix IHnew;
    Matrix HHnew;
    Matrix Bnew;
    // Reset gate parameters: input->hidden weights, hidden->hidden weights, bias.
    Matrix IHreset;
    Matrix HHreset;
    Matrix Breset;

    /** Hidden state carried between {@link #forward} calls; replaced by each call's output. */
    Matrix context;

    Nonlinearity fMix = new SigmoidUnit();
    Nonlinearity fReset = new SigmoidUnit();
    Nonlinearity fNew = new TanhUnit();

    /**
     * Creates a GRU layer with randomly initialised weight matrices.
     *
     * <p>Biases are created with {@code new Matrix(outputDimension)}. The hidden state is
     * created the same way, matching what {@link #resetState()} produces.
     *
     * @param inputDimension  size of each input (second size argument of the IH* matrices)
     * @param outputDimension size of the hidden state (first size argument of every weight
     *                        matrix and the size of every bias)
     * @param initScale       scale argument forwarded to {@code Matrix.rand} for every weight
     *                        matrix (presumably a standard deviation; confirm against Matrix)
     * @param rng             random source used for weight initialisation
     */
    public GruLayer(int inputDimension, int outputDimension, double initScale, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;

        this.IHmix = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.HHmix = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.Bmix = new Matrix(outputDimension);

        this.IHnew = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.HHnew = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.Bnew = new Matrix(outputDimension);

        this.IHreset = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.HHreset = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.Breset = new Matrix(outputDimension);

        // Without this, forward() before resetState() would pass a null state into the graph.
        // Assigned directly rather than via the overridable resetState() from a constructor.
        this.context = new Matrix(outputDimension);
    }

    /**
     * Advances the layer by one time step.
     *
     * @param input the input for this step (right operand of the IH* multiplications)
     * @param graph the graph on which all operations of this step are performed
     * @return the new hidden state, which is also stored as {@link #context}
     * @throws Exception propagated from the {@link Graph} operations
     */
    @Override // recunn.model.Model
    public Matrix forward(Matrix input, Graph graph) throws Exception {
        if (context == null) {
            // e.g. an instance deserialized before any state was ever set
            resetState();
        }
        Matrix hidden = context;

        Matrix mixGate = gate(graph, fMix, IHmix, HHmix, Bmix, input, hidden);
        Matrix resetGate = gate(graph, fReset, IHreset, HHreset, Breset, input, hidden);

        // The statements below issue graph operations in exactly the same order as the
        // original single nested expression (Java evaluates arguments left to right).
        Matrix retained = graph.elmul(mixGate, hidden);
        Matrix oneMinusMix = graph.oneMinus(mixGate);
        Matrix candidateFromInput = graph.mul(IHnew, input);
        Matrix candidateFromHidden = graph.mul(HHnew, graph.elmul(resetGate, hidden));
        Matrix candidate = graph.nonlin(fNew,
                graph.add(graph.add(candidateFromInput, candidateFromHidden), Bnew));
        Matrix next = graph.add(retained, graph.elmul(oneMinusMix, candidate));

        context = next;
        return next;
    }

    /**
     * Computes {@code f(ih * input + hh * hidden + bias)} on the graph.
     *
     * <p>Operations are issued in left-to-right order: both multiplications, then both
     * additions, then the nonlinearity.
     */
    private static Matrix gate(Graph graph, Nonlinearity f, Matrix ih, Matrix hh, Matrix bias,
            Matrix input, Matrix hidden) throws Exception {
        Matrix fromInput = graph.mul(ih, input);
        Matrix fromHidden = graph.mul(hh, hidden);
        return graph.nonlin(f, graph.add(graph.add(fromInput, fromHidden), bias));
    }

    /**
     * Replaces the hidden state with {@code new Matrix(outputDimension)}.
     *
     * <p>Presumably this is an all-zero state; confirm against Matrix.
     */
    @Override // recunn.model.Model
    public void resetState() {
        this.context = new Matrix(this.outputDimension);
    }

    /**
     * Returns the layer's nine trainable matrices in a fixed order: the mix-gate, candidate
     * and reset-gate groups, each as (input weights, hidden weights, bias).
     *
     * <p>The returned list is a new mutable list, but it holds the layer's own matrix
     * instances, not copies.
     *
     * @return the trainable parameters of this layer
     */
    @Override // recunn.model.Model
    public List<Matrix> getParameters() {
        List<Matrix> params = new ArrayList<>(9);
        params.add(this.IHmix);
        params.add(this.HHmix);
        params.add(this.Bmix);
        params.add(this.IHnew);
        params.add(this.HHnew);
        params.add(this.Bnew);
        params.add(this.IHreset);
        params.add(this.HHreset);
        params.add(this.Breset);
        return params;
    }
}
