package recunn.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.matrix.Matrix;

/* loaded from: input_file:recunn/model/RnnLayer.class */
/**
 * A single recurrent layer: each call to {@link #forward} combines the current
 * input with the layer's previous output (the "context") and stores the new
 * output as the context for the next call.
 *
 * <p>The layer is stateful. {@link #resetState()} replaces the context with a
 * freshly constructed {@code Matrix(outputDimension)}; callers are expected to
 * invoke it before starting a new sequence.
 */
public class RnnLayer implements Model {
    private static final long serialVersionUID = 1L;

    int inputDimension;
    int outputDimension;
    /** Weight matrix, created as {@code Matrix.rand(outputDimension, inputDimension + outputDimension, ...)}. */
    Matrix W;
    /** Bias, created as {@code new Matrix(outputDimension)}. */
    Matrix b;
    /** Output of the most recent {@link #forward} call, fed back in on the next call. */
    Matrix context;
    Nonlinearity f;

    /**
     * Creates a recurrent layer with randomly initialised weights.
     *
     * @param inputDimension  size of the input vector
     * @param outputDimension size of the output (and context) vector
     * @param nonlinearity    activation applied to the affine result in {@link #forward}
     * @param initScale       forwarded unchanged as the third argument of {@code Matrix.rand};
     *                        presumably the spread of the random initial weights
     *                        (TODO confirm against {@code Matrix.rand})
     * @param random          random source forwarded to {@code Matrix.rand}
     */
    public RnnLayer(int inputDimension, int outputDimension, Nonlinearity nonlinearity,
            double initScale, Random random) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;
        this.f = nonlinearity;
        // The weight matrix acts on the concatenation [input ; context], hence the summed dimension.
        this.W = Matrix.rand(outputDimension, inputDimension + outputDimension, initScale, random);
        this.b = new Matrix(outputDimension);
    }

    /**
     * Runs one step of the layer: {@code f(W * concat(input, context) + b)}.
     * The result becomes the new context and is also returned.
     *
     * <p>If no context exists yet (the layer was never reset), it is initialised
     * via {@link #resetState()} first instead of passing {@code null} into the graph.
     *
     * @param matrix the input for this step
     * @param graph  the graph through which all operations are issued
     * @return the layer output for this step (the same object stored as the new context)
     * @throws Exception propagated from the underlying graph operations
     */
    @Override // recunn.model.Model
    public Matrix forward(Matrix matrix, Graph graph) throws Exception {
        if (this.context == null) {
            // The constructor does not create a context; guard against forward() before resetState().
            resetState();
        }
        Matrix joined = graph.concatVectors(matrix, this.context);
        Matrix weighted = graph.mul(this.W, joined);
        Matrix biased = graph.add(weighted, this.b);
        Matrix output = graph.nonlin(this.f, biased);
        // Roll the activation over so the next step sees it as context.
        this.context = output;
        return output;
    }

    /** Discards the current context and replaces it with a new {@code Matrix(outputDimension)}. */
    @Override // recunn.model.Model
    public void resetState() {
        this.context = new Matrix(this.outputDimension);
    }

    /**
     * Returns the layer's parameters in the order {@code W}, {@code b}.
     *
     * @return a new mutable list holding references to the live parameter matrices
     */
    @Override // recunn.model.Model
    public List<Matrix> getParameters() {
        List<Matrix> parameters = new ArrayList<Matrix>();
        parameters.add(this.W);
        parameters.add(this.b);
        return parameters;
    }
}
