package recunn.autodiff;

import java.util.ArrayList;
import java.util.List;
import recunn.matrix.Matrix;
import recunn.model.Nonlinearity;

/* loaded from: input_file:recunn/autodiff/Graph.class */
/**
 * Tape for reverse-mode automatic differentiation over {@link Matrix} values.
 *
 * <p>Each operation computes its result eagerly. When {@link #applyBackprop} is true, it also
 * appends to {@link #backprop} a closure that propagates the result's gradient ({@code dw}) into
 * the gradients of its inputs. {@link #backward()} runs those closures in reverse recording order.
 *
 * <p>Every closure accumulates into input gradients with {@code +=}. A matrix consumed by several
 * operations, or several times by one operation, therefore receives the sum of all its gradient
 * contributions. Result matrices are freshly constructed. NOTE(review): this relies on
 * {@code new Matrix(...)} allocating a zeroed {@code dw}; confirm in {@code Matrix}.
 */
public class Graph {
    /** When false, operations compute forward results only and record no backward closures. */
    boolean applyBackprop;

    /** Backward closures, in the order their forward operations were executed. */
    List<Runnable> backprop;

    /** Creates a graph that records operations for backpropagation. */
    public Graph() {
        this(true);
    }

    /**
     * Creates a graph.
     *
     * @param applyBackprop whether operations should record backward closures
     */
    public Graph(boolean applyBackprop) {
        this.backprop = new ArrayList<Runnable>();
        this.applyBackprop = applyBackprop;
    }

    /**
     * Runs every recorded backward closure, most recently recorded first.
     *
     * <p>The recorded closures are not cleared. Calling this again replays them and adds the
     * gradients a second time.
     */
    public void backward() {
        for (int step = backprop.size() - 1; step >= 0; step--) {
            backprop.get(step).run();
        }
    }

    /**
     * Stacks two column vectors: the elements of {@code m1} followed by those of {@code m2}.
     *
     * @param m1 upper column vector
     * @param m2 lower column vector
     * @return a new vector with {@code m1.rows + m2.rows} rows holding copies of the input values
     * @throws Exception if either input has more than one column
     */
    public Matrix concatVectors(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols > 1 || m2.cols > 1) {
            throw new Exception("Expected column vectors");
        }
        final Matrix out = new Matrix(m1.rows + m2.rows);
        final int split = m1.w.length;
        // Only values are copied. The result is a fresh intermediate, so its gradient must start
        // at zero (as for every other operation's result) and it has no optimizer state.
        System.arraycopy(m1.w, 0, out.w, 0, split);
        System.arraycopy(m2.w, 0, out.w, split, m2.w.length);
        if (applyBackprop) {
            backprop.add(new Runnable() {
                @Override
                public void run() {
                    // Accumulate rather than overwrite. m1/m2 may also receive gradient from other
                    // operations, from being passed as both arguments, or from being
                    // concatenated more than once on this tape.
                    for (int i = 0; i < split; i++) {
                        m1.dw[i] += out.dw[i];
                    }
                    for (int i = 0; i < m2.w.length; i++) {
                        m2.dw[i] += out.dw[split + i];
                    }
                }
            });
        }
        return out;
    }

    /**
     * Applies {@code nonlinearity.forward} to every element of {@code m}.
     *
     * <p>The backward closure multiplies the result gradient by {@code nonlinearity.backward(x)},
     * evaluated at the <em>input</em> value {@code x}, not at the activation.
     *
     * @param nonlinearity element-wise function and its derivative
     * @param m input matrix
     * @return a new matrix of the same shape as {@code m}
     * @throws Exception declared for API compatibility; not thrown by this method itself
     */
    public Matrix nonlin(final Nonlinearity nonlinearity, final Matrix m) throws Exception {
        final Matrix out = new Matrix(m.rows, m.cols);
        final int size = m.w.length;
        for (int i = 0; i < size; i++) {
            out.w[i] = nonlinearity.forward(m.w[i]);
        }
        if (applyBackprop) {
            backprop.add(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < size; i++) {
                        m.dw[i] += nonlinearity.backward(m.w[i]) * out.dw[i];
                    }
                }
            });
        }
        return out;
    }

    /**
     * Matrix product {@code m1 * m2}. Storage is row-major: element (r, c) of an
     * {@code R x C} matrix is at index {@code C * r + c}.
     *
     * @param m1 left operand ({@code rows x inner})
     * @param m2 right operand ({@code inner x cols})
     * @return a new {@code rows x cols} matrix
     * @throws Exception if {@code m1.cols != m2.rows}
     */
    public Matrix mul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols != m2.rows) {
            throw new Exception("matrix dimension mismatch");
        }
        final int rows = m1.rows;
        final int inner = m1.cols;
        final int cols = m2.cols;
        final Matrix out = new Matrix(rows, cols);
        for (int r = 0; r < rows; r++) {
            final int leftRow = inner * r;
            for (int c = 0; c < cols; c++) {
                double dot = 0.0d;
                for (int k = 0; k < inner; k++) {
                    dot += m1.w[leftRow + k] * m2.w[cols * k + c];
                }
                out.w[cols * r + c] = dot;
            }
        }
        if (applyBackprop) {
            backprop.add(new Runnable() {
                @Override
                public void run() {
                    for (int r = 0; r < rows; r++) {
                        final int leftRow = inner * r;
                        for (int c = 0; c < cols; c++) {
                            final double grad = out.dw[cols * r + c];
                            for (int k = 0; k < inner; k++) {
                                final int rightIdx = cols * k + c;
                                // d(out)/d(m1[r][k]) = m2[k][c]; d(out)/d(m2[k][c]) = m1[r][k]
                                m1.dw[leftRow + k] += m2.w[rightIdx] * grad;
                                m2.dw[rightIdx] += m1.w[leftRow + k] * grad;
                            }
                        }
                    }
                }
            });
        }
        return out;
    }

    /**
     * Element-wise sum {@code m1 + m2}.
     *
     * @return a new matrix of the common shape
     * @throws Exception if the shapes differ
     */
    public Matrix add(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; i++) {
            out.w[i] = m1.w[i] + m2.w[i];
        }
        if (applyBackprop) {
            backprop.add(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < m1.w.length; i++) {
                        m1.dw[i] += out.dw[i];
                        m2.dw[i] += out.dw[i];
                    }
                }
            });
        }
        return out;
    }

    /**
     * Computes {@code 1 - m} as {@code sub(Matrix.ones(rows, cols), m)}.
     * NOTE(review): assumes {@code Matrix.ones} is filled with 1.0.
     */
    public Matrix oneMinus(Matrix m) throws Exception {
        return sub(Matrix.ones(m.rows, m.cols), m);
    }

    /** Element-wise difference {@code m1 - m2}, built as {@code add(m1, neg(m2))}. */
    public Matrix sub(Matrix m1, Matrix m2) throws Exception {
        return add(m1, neg(m2));
    }

    /**
     * Scales {@code m} by {@code d}, as an element-wise product with
     * {@code Matrix.uniform(m.rows, m.cols, d)}. NOTE(review): assumes {@code uniform} fills
     * every element with {@code d}. The gradient accumulated into that temporary is discarded.
     */
    public Matrix smul(Matrix m, double d) throws Exception {
        return elmul(m, Matrix.uniform(m.rows, m.cols, d));
    }

    /** Same as {@link #smul(Matrix, double)} with the arguments swapped. */
    public Matrix smul(double d, Matrix m) throws Exception {
        return smul(m, d);
    }

    /**
     * Negates {@code m} via an element-wise product with {@code Matrix.negones(rows, cols)}.
     * NOTE(review): assumes {@code negones} is filled with -1.0.
     */
    public Matrix neg(Matrix m) throws Exception {
        return elmul(Matrix.negones(m.rows, m.cols), m);
    }

    /**
     * Element-wise (Hadamard) product {@code m1 .* m2}.
     *
     * @return a new matrix of the common shape
     * @throws Exception if the shapes differ
     */
    public Matrix elmul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; i++) {
            out.w[i] = m1.w[i] * m2.w[i];
        }
        if (applyBackprop) {
            backprop.add(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < m1.w.length; i++) {
                        m1.dw[i] += m2.w[i] * out.dw[i];
                        m2.dw[i] += m1.w[i] * out.dw[i];
                    }
                }
            });
        }
        return out;
    }
}
