package recunn.model;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import recunn.autodiff.Graph;
import recunn.matrix.Matrix;

/* loaded from: input_file:recunn/model/LinearLayer.class */
/**
 * A linear layer with no bias term: its only trainable parameter is the weight matrix {@code W}.
 *
 * <p>The forward pass calls {@code graph.mul(W, input)} with the weights on the left and the
 * input on the right. Presumably this records the matrix product on the graph so gradients can
 * flow back into {@code W} (NOTE(review): confirm against {@code Graph#mul}).
 */
public class LinearLayer implements Model {
    private static final long serialVersionUID = 1;

    /**
     * Weight matrix, created by {@code Matrix.rand(outputDimension, inputDimension, ...)}.
     * Assuming {@code Matrix.rand} takes (rows, cols, ...), this is an
     * (outputDimension x inputDimension) matrix — TODO confirm.
     */
    Matrix W;

    /**
     * Creates a layer with randomly initialized weights.
     *
     * @param inputDimension   passed as the second (column) argument to {@code Matrix.rand}
     * @param outputDimension  passed as the first (row) argument to {@code Matrix.rand}
     * @param initParamsStdDev scale passed to {@code Matrix.rand}; presumably the standard
     *                         deviation of the initial weights — TODO confirm
     * @param rng              random source used for the initial weights
     */
    public LinearLayer(int inputDimension, int outputDimension, double initParamsStdDev, Random rng) {
        this.W = Matrix.rand(outputDimension, inputDimension, initParamsStdDev, rng);
    }

    /**
     * Applies the layer to {@code matrix} by delegating to {@code graph.mul(W, matrix)}.
     *
     * @param matrix the input, used as the right-hand operand of the multiplication
     * @param graph  the computation graph that performs (and presumably records) the operation
     * @return the result of {@code graph.mul(W, matrix)}
     * @throws Exception whatever {@code graph.mul} throws (e.g. on incompatible dimensions;
     *                   not visible here)
     */
    @Override // recunn.model.Model
    public Matrix forward(Matrix matrix, Graph graph) throws Exception {
        return graph.mul(this.W, matrix);
    }

    /**
     * No-op. {@link #forward} only reads {@code W} and keeps no per-sequence state, so there
     * is nothing to reset.
     */
    @Override // recunn.model.Model
    public void resetState() {
    }

    /**
     * Returns the trainable parameters of this layer.
     *
     * @return a new mutable list containing only {@code W}. A fresh {@link ArrayList} is
     *         returned on every call, so callers may modify the list without affecting the layer.
     */
    @Override // recunn.model.Model
    public List<Matrix> getParameters() {
        // Parameterized instead of the raw ArrayList the original used, avoiding an unchecked
        // conversion to List<Matrix>.
        List<Matrix> parameters = new ArrayList<>();
        parameters.add(this.W);
        return parameters;
    }
}
