/*
 * Decompiled with CFR 0.152.
 */
package org.ujmp.core.doublematrix.impl;

import java.util.concurrent.Callable;
import org.ujmp.core.doublematrix.impl.BlockDenseDoubleMatrix2D;
import org.ujmp.core.util.VerifyUtil;

/**
 * Task that multiplies a rectangular range of blocks of matrix {@code a} with
 * the matching blocks of matrix {@code b} and hands every resulting block to
 * {@code c} via {@code addBlockData}.
 *
 * <p>The ranges {@code [fromM, toM)}, {@code [fromN, toN)} and
 * {@code [fromK, toK)} are walked in increments of the block stripe size taken
 * from {@code a.layout.blockStripe}. The instance implements {@link Callable}
 * so that it can be submitted to an executor; {@link #call()} simply runs
 * {@link #multiply()}.
 */
public class BlockMultiply
implements Callable<Void> {
    /** Edge length of a full block, read from {@code a.layout.blockStripe}. */
    private final int blockStripeSize;
    /** Start (inclusive) and end (exclusive) of the row range of A. */
    private final int fromM;
    private final int toM;
    /** Start (inclusive) and end (exclusive) of the shared dimension (columns of A / rows of B). */
    private final int fromN;
    private final int toN;
    /** Start (inclusive) and end (exclusive) of the column range of B. */
    private final int fromK;
    private final int toK;
    private final BlockDenseDoubleMatrix2D matrixA;
    private final BlockDenseDoubleMatrix2D matrixB;
    private final BlockDenseDoubleMatrix2D matrixC;

    /**
     * Creates a multiplication task for the given index ranges.
     *
     * <p>All arguments are checked by {@code verifyInput} before any field is
     * assigned.
     *
     * @param a     left operand
     * @param b     right operand
     * @param c     matrix that receives the product blocks
     * @param fromM first row of {@code a} to process
     * @param toM   end of the row range of {@code a} (loop bound, exclusive)
     * @param fromN first column of {@code a} / row of {@code b} to process
     * @param toN   end of the shared-dimension range (loop bound, exclusive)
     * @param fromK first column of {@code b} to process
     * @param toK   end of the column range of {@code b} (loop bound, exclusive)
     */
    public BlockMultiply(BlockDenseDoubleMatrix2D a, BlockDenseDoubleMatrix2D b, BlockDenseDoubleMatrix2D c, int fromM, int toM, int fromN, int toN, int fromK, int toK) {
        verifyInput(a, b, c, fromM, toM, fromN, toN, fromK, toK);
        this.matrixA = a;
        this.matrixB = b;
        this.matrixC = c;
        this.fromM = fromM;
        this.toM = toM;
        this.fromN = fromN;
        this.toN = toN;
        this.fromK = fromK;
        this.toK = toK;
        this.blockStripeSize = a.layout.blockStripe;
    }

    /**
     * Runs {@link #multiply()}.
     *
     * @return always {@code null}
     */
    @Override
    public Void call() {
        multiply();
        return null;
    }

    /**
     * Performs the block-wise multiplication over the configured ranges.
     *
     * <p>For every (row block, column block) pair a fresh accumulator array is
     * allocated, the products of all block pairs along the shared dimension are
     * summed into it, and the accumulator is then passed to
     * {@code matrixC.addBlockData}. Block pairs for which either operand block
     * is {@code null} are skipped.
     */
    protected final void multiply() {
        final int stripe = blockStripeSize;
        final int fullBlockLength = stripe * stripe;

        for (int row = fromM; row < toM; row += stripe) {
            final int rowsInA = matrixA.layout.getRowsInBlock(row);

            for (int col = fromK; col < toK; col += stripe) {
                final int colsInB = matrixB.layout.getColumnsInBlock(col);
                final double[] product = new double[rowsInA * colsInB];

                for (int inner = fromN; inner < toN; inner += stripe) {
                    final double[] left = matrixA.layout.toRowMajorBlock(matrixA, row, inner);
                    final double[] right = matrixB.layout.toColMajorBlock(matrixB, inner, col);

                    if (left != null && right != null) {
                        final boolean bothFull = left.length == fullBlockLength && right.length == fullBlockLength;
                        if (bothFull) {
                            // Both operands are complete stripe x stripe blocks: use the unrolled kernel.
                            multiplyAxB(left, right, product, stripe);
                        } else {
                            // Partial block: derive the shared dimension from the array lengths.
                            final int colsInA = left.length / rowsInA;
                            final int rowsInB = right.length / colsInB;
                            VerifyUtil.verifyTrue(colsInA == rowsInB, "aCols!=bRows");
                            this.multiplyRowMajorTimesColumnMajorBlocks(left, right, product, rowsInA, colsInA, colsInB);
                        }
                    }
                }

                matrixC.addBlockData(row, col, product);
            }
        }
    }

    /**
     * Kernel for two full {@code step x step} blocks.
     *
     * <p>{@code aBlock} is read row by row and {@code bBlock} column by column,
     * each row/column occupying {@code step} consecutive elements. The inner dot
     * product first handles {@code step % 3} single terms and then
     * {@code step / 3} groups of three terms. Results are added to
     * {@code cBlock} at index {@code row * step + column}.
     *
     * @param aBlock left operand block
     * @param bBlock right operand block
     * @param cBlock accumulator that the products are added to
     * @param step   edge length of the blocks
     */
    private static void multiplyAxB(double[] aBlock, double[] bBlock, double[] cBlock, int step) {
        final int singles = step % 3;
        final int triples = step / 3;
        final int area = step * step;

        for (int aStart = 0; aStart < area; aStart += step) {
            int cPos = aStart;
            for (int bStart = 0; bStart < area; bStart += step) {
                int aPos = aStart;
                int bPos = bStart;
                double dot = 0.0;

                // Leading terms that do not fill a complete group of three.
                for (int j = 0; j < singles; j++) {
                    dot += aBlock[aPos++] * bBlock[bPos++];
                }
                // Remaining terms, three per iteration; the three products are
                // summed first and then added to the running total.
                for (int j = 0; j < triples; j++) {
                    dot += aBlock[aPos] * bBlock[bPos]
                            + aBlock[aPos + 1] * bBlock[bPos + 1]
                            + aBlock[aPos + 2] * bBlock[bPos + 2];
                    aPos += 3;
                    bPos += 3;
                }

                cBlock[cPos++] += dot;
            }
        }
    }

    /**
     * Generic kernel for blocks of arbitrary size.
     *
     * <p>Element {@code (i, j)} of A is read from {@code aBlock[i * bRows + j]}
     * and element {@code (j, k)} of B from {@code bBlock[k * bRows + j]}; the
     * dot product for {@code (i, k)} is added to {@code cBlock[i * bCols + k]}.
     *
     * @param aBlock left operand block
     * @param bBlock right operand block
     * @param cBlock accumulator that the products are added to
     * @param aRows  number of rows in {@code aBlock}
     * @param bRows  number of rows in {@code bBlock}; also used as the number
     *               of columns in {@code aBlock}
     * @param bCols  number of columns in {@code bBlock}
     */
    public void multiplyRowMajorTimesColumnMajorBlocks(double[] aBlock, double[] bBlock, double[] cBlock, int aRows, int bRows, int bCols) {
        final int shared = bRows;

        for (int i = 0, aBase = 0, cBase = 0; i < aRows; i++, aBase += shared, cBase += bCols) {
            for (int k = 0, bBase = 0; k < bCols; k++, bBase += shared) {
                double dot = 0.0;
                for (int j = 0; j < shared; j++) {
                    dot += aBlock[aBase + j] * bBlock[bBase + j];
                }
                cBlock[cBase + k] += dot;
            }
        }
    }

    /**
     * Checks the constructor arguments through {@code VerifyUtil.verifyTrue}.
     *
     * <p>Verified conditions, in order: none of the matrices is {@code null};
     * each {@code from*} index lies between 0 and the corresponding dimension;
     * each {@code to*} index lies between its {@code from*} index and the
     * corresponding dimension; the dimensions of {@code a}, {@code b} and
     * {@code c} are compatible for {@code c = a * b}.
     */
    private static void verifyInput(BlockDenseDoubleMatrix2D a, BlockDenseDoubleMatrix2D b, BlockDenseDoubleMatrix2D c, int fromM, int toM, int fromN, int toN, int fromK, int toK) {
        VerifyUtil.verifyTrue(a != null, "a cannot be null");
        VerifyUtil.verifyTrue(b != null, "b cannot be null");
        VerifyUtil.verifyTrue(c != null, "c cannot be null");

        // Row range of A.
        VerifyUtil.verifyTrue(isBetween(fromM, 0, a.getRowCount()), "Invalid argument : fromM");
        VerifyUtil.verifyTrue(isBetween(toM, fromM, a.getRowCount()), "Invalid argument : fromM/toM");

        // Shared dimension, checked against the columns of A.
        VerifyUtil.verifyTrue(isBetween(fromN, 0, a.getColumnCount()), "Invalid argument : fromN");
        VerifyUtil.verifyTrue(isBetween(toN, fromN, a.getColumnCount()), "Invalid argument : fromN/toN");

        // Column range of B.
        VerifyUtil.verifyTrue(isBetween(fromK, 0, b.getColumnCount()), "Invalid argument : fromK");
        VerifyUtil.verifyTrue(isBetween(toK, fromK, b.getColumnCount()), "Invalid argument : fromK/toK");

        // Shape compatibility of the three matrices.
        VerifyUtil.verifyTrue(a.getColumnCount() == b.getRowCount(), "Invalid argument : a.columns != b.rows");
        VerifyUtil.verifyTrue(a.getRowCount() == c.getRowCount(), "Invalid argument : a.rows != c.rows");
        VerifyUtil.verifyTrue(b.getColumnCount() == c.getColumnCount(), "Invalid argument : b.columns != c.columns");
    }

    /** Returns whether {@code lower <= value <= upper}. */
    private static boolean isBetween(int value, int lower, long upper) {
        return (long) value <= upper && value >= lower;
    }
}

