在网上有很多讲如何实现Tiled Matrix Multiplication的文章,不过大部分只对方阵且尺寸等于Tile尺寸整倍数的矩阵有效。我在这里贴出实现任意尺寸矩阵乘法的代码。

至于计算方法,主要参考了这篇文章[1]。简单讲就是把每个block看做矩阵P(即输出矩阵)的一块固定大小的正方形Tile,每个thread负责计算这个Tile中的一个元素。当然矩阵P的宽或高未必是Tile尺寸的整倍数,我们要在代码中特别注意这一点。

首先计算我们需要多少个block,即grid的大小:

int dimX = (int)(ceil((float)P.width / TILE_WIDTH));
int dimY = (int)(ceil((float)P.height / TILE_WIDTH));
dim3 dimGrid(dimX, dimY);
dim3 dimBlock(TILE_WIDTH, TILE_WIDTH);

在每个block中(即一个tile中)的所有thread共享一个shared memory,所以为了提高计算效率,我们首先把这个tile需要的数据从两个输入矩阵中拷过来(即从global memory拷贝到shared memory里),存入两个临时矩阵之中。因为shared memory的大小是有限制的,所以要谨慎选择tile的大小,在这里我选择的是16×16的tile。

根据矩阵乘法的原理,每个tile的计算过程如图所示(图片来自于课件,出处请参见图片左下脚注):

tiledmul

在图示的例子中,每个block需要循环三次,分别从两个输入矩阵中从左到右和从上到下各提取三次与tile同样大小的数据块才能完成一个tile的计算。当然这个例子很简单了,如果输入矩阵不是方阵或宽高不是tile尺寸的整倍数,我们也要同样提取,然后在越界元素的位置填充0,以便不影响计算结果。因为一个block中所有thread都是同时运行的,所以这些0元素的拷贝和计算不会拖慢程序的运行。

下面是完整的kernel函数代码:

__global__ void MatrixMulKernel(Matrix M, Matrix N, Matrix P)
{
    __shared__ float sharedM[TILE_WIDTH][TILE_WIDTH];
    __shared__ float sharedN[TILE_WIDTH][TILE_WIDTH];
    int bx = blockIdx.x;
    int by = blockIdx.y;
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int row = by*TILE_WIDTH + ty;
    int col = bx*TILE_WIDTH + tx;
    float v = 0.0;

    for (int i = 0; i < (int)(ceil((float)M.width/TILE_WIDTH)); i++)
    {
        if (i*TILE_WIDTH + tx < M.width && row < M.height)
            sharedM[ty][tx] = M.elements[row*M.width + i*TILE_WIDTH + tx];
        else
            sharedM[ty][tx] = 0.0;

        if (i*TILE_WIDTH + ty < N.height && col < N.width)
            sharedN[ty][tx] = N.elements[(i*TILE_WIDTH + ty)*N.width + col];
        else
            sharedN[ty][tx] = 0.0;
        __syncthreads();

        for(int j = 0; j < TILE_WIDTH; j++)
            v += sharedM[ty][j] * sharedN[j][tx];
        __syncthreads();
    }

    if (row < P.height && col < P.width)
        P.elements[row*P.width + col] = v;
}

参考资料:
[1] CUDA矩阵乘法——利用共享存储器

» 转载请注明来源及链接:未来代码研究所

Related Posts:

2 Responses to “在CUDA中实现任意尺寸的矩阵乘法(Tiled Matrix Multiplication)”

  • Jokeren says:

    lz偷懒啊,这幅图明显就是Wen-Mei Hwu给我们上课时候的图(看脚注),算法也是。
    希望能够标注一下。

    • 暗影吉他手 says:

      额,我以为既然有脚注就不用再标注来源了……不过还是标注一下吧,谢谢指正~
      至于算法,图示的例子只演示了最简单的情况,没有考虑到尺寸非整倍数的情况。kernel函数代码完全是我自己写的。

Leave a Reply

World Line
Time Machine
Online Tools