package recunn.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.matrix.Matrix;

/* loaded from: input_file:recunn/model/LstmLayer.class */
/**
 * A single LSTM-style recurrent layer built from {@link Graph} operations.
 *
 * <p>Each call to {@link #forward(Matrix, Graph)} combines the supplied input with the
 * layer's stored hidden state and cell state, then replaces both states with the newly
 * computed values. Call {@link #resetState()} to start a new sequence.
 *
 * <p>For every gate {@code g} in {input, forget, output, cell-input} the layer computes
 * {@code f_g(add(add(mul(W_gx, input), mul(W_gh, hidden)), b_g))}. The exact semantics of
 * {@code mul}/{@code add}/{@code elmul} are defined by {@link Graph}; presumably matrix
 * product, element-wise sum and element-wise product respectively (confirm against Graph).
 */
public class LstmLayer implements Model {
    private static final long serialVersionUID = 1L;

    int inputDimension;
    int outputDimension;

    // Input gate: input weights, recurrent weights, bias.
    Matrix Wix;
    Matrix Wih;
    Matrix bi;

    // Forget gate: input weights, recurrent weights, bias.
    Matrix Wfx;
    Matrix Wfh;
    Matrix bf;

    // Output gate: input weights, recurrent weights, bias.
    Matrix Wox;
    Matrix Woh;
    Matrix bo;

    // Cell-input (candidate) transform: input weights, recurrent weights, bias.
    Matrix Wcx;
    Matrix Wch;
    Matrix bc;

    // Recurrent state carried between forward() calls; assigned by resetState() and forward().
    Matrix hiddenContext;
    Matrix cellContext;

    Nonlinearity fInputGate = new SigmoidUnit();
    Nonlinearity fForgetGate = new SigmoidUnit();
    Nonlinearity fOutputGate = new SigmoidUnit();
    Nonlinearity fCellInput = new TanhUnit();
    Nonlinearity fCellOutput = new TanhUnit();

    /**
     * Creates a layer with freshly initialised parameters.
     *
     * <p>Weight matrices come from {@code Matrix.rand}; the input, output and cell-input
     * biases come from {@code new Matrix(outputDimension)} and the forget-gate bias from
     * {@code Matrix.ones(outputDimension, 1)}. The recurrent state is not created here.
     *
     * @param inputDimension  size of the input fed to {@link #forward(Matrix, Graph)}
     * @param outputDimension size of the hidden and cell state
     * @param initScale       passed unchanged as the third argument of {@code Matrix.rand}
     *                        (presumably the spread of the random initial weights; confirm
     *                        against Matrix.rand)
     * @param rng             random source handed to {@code Matrix.rand}
     */
    public LstmLayer(int inputDimension, int outputDimension, double initScale, Random rng) {
        this.inputDimension = inputDimension;
        this.outputDimension = outputDimension;

        this.Wix = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.Wih = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.bi = new Matrix(outputDimension);

        this.Wfx = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.Wfh = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        // Unlike the other biases, the forget-gate bias is created with Matrix.ones.
        this.bf = Matrix.ones(outputDimension, 1);

        this.Wox = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.Woh = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.bo = new Matrix(outputDimension);

        this.Wcx = Matrix.rand(outputDimension, inputDimension, initScale, rng);
        this.Wch = Matrix.rand(outputDimension, outputDimension, initScale, rng);
        this.bc = new Matrix(outputDimension);
    }

    /**
     * Runs one step of the layer and advances its recurrent state.
     *
     * <p>If the hidden or cell state has not been created yet (no prior call to
     * {@link #resetState()}), the missing state is created exactly as
     * {@link #resetState()} would, instead of passing {@code null} into the graph.
     *
     * @param input the input for this step
     * @param graph the graph on which every operation of this step is performed
     * @return the new hidden state, which is also stored for the next step
     * @throws Exception if any graph operation fails
     */
    @Override // recunn.model.Model
    public Matrix forward(Matrix input, Graph graph) throws Exception {
        if (this.hiddenContext == null) {
            this.hiddenContext = new Matrix(this.outputDimension);
        }
        if (this.cellContext == null) {
            this.cellContext = new Matrix(this.outputDimension);
        }

        Matrix inputGate = gate(graph, this.fInputGate, this.Wix, this.Wih, this.bi, input);
        Matrix forgetGate = gate(graph, this.fForgetGate, this.Wfx, this.Wfh, this.bf, input);
        Matrix outputGate = gate(graph, this.fOutputGate, this.Wox, this.Woh, this.bo, input);

        // The retained part of the old cell state is computed before the cell input so
        // that graph operations are issued in the same order as before.
        Matrix retainedCell = graph.elmul(forgetGate, this.cellContext);
        Matrix cellInput = gate(graph, this.fCellInput, this.Wcx, this.Wch, this.bc, input);
        Matrix writtenCell = graph.elmul(inputGate, cellInput);
        Matrix nextCell = graph.add(retainedCell, writtenCell);

        Matrix nextHidden = graph.elmul(outputGate, graph.nonlin(this.fCellOutput, nextCell));

        this.hiddenContext = nextHidden;
        this.cellContext = nextCell;
        return nextHidden;
    }

    /**
     * Computes {@code f(add(add(mul(inputWeights, input), mul(hiddenWeights, hiddenContext)), bias))}
     * on the given graph, using the layer's current hidden state.
     */
    private Matrix gate(Graph graph, Nonlinearity f, Matrix inputWeights, Matrix hiddenWeights,
            Matrix bias, Matrix input) throws Exception {
        Matrix fromInput = graph.mul(inputWeights, input);
        Matrix fromHidden = graph.mul(hiddenWeights, this.hiddenContext);
        Matrix preActivation = graph.add(graph.add(fromInput, fromHidden), bias);
        return graph.nonlin(f, preActivation);
    }

    /**
     * Replaces the hidden and cell state with new {@code Matrix(outputDimension)} instances
     * (presumably zero-filled; confirm against the Matrix constructor).
     */
    @Override // recunn.model.Model
    public void resetState() {
        this.hiddenContext = new Matrix(this.outputDimension);
        this.cellContext = new Matrix(this.outputDimension);
    }

    /**
     * Returns the twelve parameter matrices in a new mutable list, ordered gate by gate
     * (input, forget, output, cell-input) and within each gate as input weights,
     * recurrent weights, bias. The list holds the layer's own matrices, not copies.
     *
     * @return a new list containing the layer's parameter matrices
     */
    @Override // recunn.model.Model
    public List<Matrix> getParameters() {
        List<Matrix> parameters = new ArrayList<>(12);
        parameters.add(this.Wix);
        parameters.add(this.Wih);
        parameters.add(this.bi);
        parameters.add(this.Wfx);
        parameters.add(this.Wfh);
        parameters.add(this.bf);
        parameters.add(this.Wox);
        parameters.add(this.Woh);
        parameters.add(this.bo);
        parameters.add(this.Wcx);
        parameters.add(this.Wch);
        parameters.add(this.bc);
        return parameters;
    }
}
