/*
 * Decompiled with CFR 0.152.
 */
package recunn.autodiff;

import java.util.ArrayList;
import java.util.List;
import recunn.matrix.Matrix;
import recunn.model.Nonlinearity;

/**
 * Tape-based reverse-mode automatic differentiation over {@link Matrix} values.
 *
 * <p>Every operation computes its result eagerly. When {@link #applyBackprop} is {@code true},
 * the operation also appends a closure to {@link #backprop}. That closure propagates the gradient
 * held in the result's {@code dw} into the {@code dw} arrays of the operation's inputs.
 * {@link #backward()} replays the closures in reverse recording order.
 *
 * <p>Every backward closure <em>accumulates</em> ({@code +=}) into input gradients. A matrix that
 * feeds several operations therefore receives the sum of all its gradient contributions. This
 * class never seeds or clears any {@code dw} array; that is left to the caller.
 *
 * <p>Matrices are indexed row-major throughout ({@code w[cols * row + col]}).
 */
public class Graph {
    /** Whether operations record backward closures; {@code false} yields a forward-only graph. */
    boolean applyBackprop;
    /** Recorded backward closures, in forward (recording) order. */
    List<Runnable> backprop = new ArrayList<Runnable>();

    /** Creates a graph that records backward closures. */
    public Graph() {
        this(true);
    }

    /**
     * Creates a graph.
     *
     * @param applyBackprop {@code true} to record backward closures, {@code false} for
     *     forward-only evaluation
     */
    public Graph(boolean applyBackprop) {
        this.applyBackprop = applyBackprop;
    }

    /**
     * Runs every recorded backward closure, most recently recorded first.
     *
     * <p>The closure list is not cleared afterwards. Calling this method again re-applies every
     * closure and accumulates gradient a second time.
     */
    public void backward() {
        for (int i = this.backprop.size() - 1; i >= 0; --i) {
            this.backprop.get(i).run();
        }
    }

    /**
     * Stacks two column vectors: the result holds the entries of {@code m1} followed by those of
     * {@code m2}.
     *
     * <p>Only the values ({@code w}) are copied into the result. The backward closure adds the
     * matching slices of the result's gradient into {@code m1.dw} and {@code m2.dw}.
     *
     * @param m1 first column vector (at most one column)
     * @param m2 second column vector (at most one column)
     * @return a new matrix with {@code m1.rows + m2.rows} rows
     * @throws Exception if either argument has more than one column
     */
    public Matrix concatVectors(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols > 1 || m2.cols > 1) {
            throw new Exception("Expected column vectors");
        }
        final int len1 = m1.w.length;
        final int len2 = m2.w.length;
        // NOTE(review): assumes the one-argument Matrix constructor allocates a (rows x 1) vector
        // whose dw starts zeroed, so the result's gradient begins at zero. Confirm in Matrix.
        final Matrix out = new Matrix(m1.rows + m2.rows);
        System.arraycopy(m1.w, 0, out.w, 0, len1);
        System.arraycopy(m2.w, 0, out.w, len1, len2);
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    // Accumulate rather than assign. Closures recorded after this one have
                    // already run and may have added gradient into m1/m2, which overwriting
                    // would discard. Accumulating also handles m1 == m2 correctly.
                    for (int i = 0; i < len1; ++i) {
                        m1.dw[i] += out.dw[i];
                    }
                    for (int i = 0; i < len2; ++i) {
                        m2.dw[i] += out.dw[len1 + i];
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Applies {@code neuron} element-wise.
     *
     * <p>The backward closure multiplies the result's gradient by {@code neuron.backward(x)},
     * evaluated at the original input value {@code x}. This assumes {@code backward(x)} returns
     * the derivative of {@code forward} at {@code x}; confirm against {@code Nonlinearity}.
     *
     * @param neuron the element-wise function
     * @param m input matrix
     * @return a new matrix of the same shape as {@code m}
     * @throws Exception declared for API compatibility; not thrown by this method itself
     */
    public Matrix nonlin(final Nonlinearity neuron, final Matrix m) throws Exception {
        final Matrix out = new Matrix(m.rows, m.cols);
        final int n = m.w.length;
        for (int i = 0; i < n; ++i) {
            out.w[i] = neuron.forward(m.w[i]);
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    for (int i = 0; i < n; ++i) {
                        m.dw[i] += neuron.backward(m.w[i]) * out.dw[i];
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Matrix product {@code m1 * m2}.
     *
     * @param m1 left operand, shape (r x k)
     * @param m2 right operand, shape (k x c)
     * @return a new (r x c) matrix
     * @throws Exception if {@code m1.cols != m2.rows}
     */
    public Matrix mul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.cols != m2.rows) {
            throw new Exception("matrix dimension mismatch");
        }
        final int rows = m1.rows;
        final int inner = m1.cols;
        final int cols = m2.cols;
        final Matrix out = new Matrix(rows, cols);
        for (int i = 0; i < rows; ++i) {
            for (int j = 0; j < cols; ++j) {
                double dot = 0.0;
                for (int k = 0; k < inner; ++k) {
                    dot += m1.w[inner * i + k] * m2.w[cols * k + j];
                }
                out.w[cols * i + j] = dot;
            }
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    // dL/dm1[i][k] += m2[k][j] * g[i][j];  dL/dm2[k][j] += m1[i][k] * g[i][j]
                    for (int i = 0; i < rows; ++i) {
                        for (int j = 0; j < cols; ++j) {
                            final double grad = out.dw[cols * i + j];
                            for (int k = 0; k < inner; ++k) {
                                m1.dw[inner * i + k] += m2.w[cols * k + j] * grad;
                                m2.dw[cols * k + j] += m1.w[inner * i + k] * grad;
                            }
                        }
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Element-wise sum {@code m1 + m2}.
     *
     * @param m1 first operand
     * @param m2 second operand, same shape as {@code m1}
     * @return a new matrix of the same shape
     * @throws Exception if the shapes differ
     */
    public Matrix add(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] + m2.w[i];
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    // The sum passes its gradient through unchanged to both operands.
                    for (int i = 0; i < m1.w.length; ++i) {
                        m1.dw[i] += out.dw[i];
                        m2.dw[i] += out.dw[i];
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }

    /**
     * Computes {@code 1 - m} element-wise, as {@code sub(ones, m)}.
     *
     * @param m input matrix
     * @return a new matrix of the same shape
     * @throws Exception propagated from {@link #sub}
     */
    public Matrix oneMinus(Matrix m) throws Exception {
        Matrix ones = Matrix.ones(m.rows, m.cols);
        return this.sub(ones, m);
    }

    /**
     * Element-wise difference {@code m1 - m2}, built as {@code add(m1, neg(m2))}.
     *
     * @param m1 minuend
     * @param m2 subtrahend, same shape as {@code m1}
     * @return a new matrix of the same shape
     * @throws Exception if the shapes differ
     */
    public Matrix sub(Matrix m1, Matrix m2) throws Exception {
        return this.add(m1, this.neg(m2));
    }

    /**
     * Multiplies every element of {@code m} by the scalar {@code s}.
     *
     * <p>Implemented as {@code elmul} with {@code Matrix.uniform(rows, cols, s)}, presumably a
     * matrix filled with {@code s}. That constant also receives gradient, which is simply
     * discarded.
     *
     * @param m input matrix
     * @param s scalar factor
     * @return a new matrix of the same shape
     * @throws Exception propagated from {@link #elmul}
     */
    public Matrix smul(Matrix m, double s) throws Exception {
        Matrix m2 = Matrix.uniform(m.rows, m.cols, s);
        return this.elmul(m, m2);
    }

    /**
     * Argument-order convenience overload of {@link #smul(Matrix, double)}.
     *
     * @param s scalar factor
     * @param m input matrix
     * @return a new matrix of the same shape
     * @throws Exception propagated from {@link #elmul}
     */
    public Matrix smul(double s, Matrix m) throws Exception {
        return this.smul(m, s);
    }

    /**
     * Negates {@code m} element-wise.
     *
     * <p>Implemented as {@code elmul} with {@code Matrix.negones(rows, cols)}, presumably a
     * matrix of {@code -1}.
     *
     * @param m input matrix
     * @return a new matrix of the same shape
     * @throws Exception propagated from {@link #elmul}
     */
    public Matrix neg(Matrix m) throws Exception {
        Matrix negones = Matrix.negones(m.rows, m.cols);
        return this.elmul(negones, m);
    }

    /**
     * Element-wise (Hadamard) product {@code m1 .* m2}.
     *
     * @param m1 first operand
     * @param m2 second operand, same shape as {@code m1}
     * @return a new matrix of the same shape
     * @throws Exception if the shapes differ
     */
    public Matrix elmul(final Matrix m1, final Matrix m2) throws Exception {
        if (m1.rows != m2.rows || m1.cols != m2.cols) {
            throw new Exception("matrix dimension mismatch");
        }
        final Matrix out = new Matrix(m1.rows, m1.cols);
        for (int i = 0; i < m1.w.length; ++i) {
            out.w[i] = m1.w[i] * m2.w[i];
        }
        if (this.applyBackprop) {
            Runnable bp = new Runnable() {

                @Override
                public void run() {
                    // Product rule: each operand's gradient is the other operand times the
                    // upstream gradient.
                    for (int i = 0; i < m1.w.length; ++i) {
                        m1.dw[i] += m2.w[i] * out.dw[i];
                        m2.dw[i] += m1.w[i] * out.dw[i];
                    }
                }
            };
            this.backprop.add(bp);
        }
        return out;
    }
}

