- Published on
General Matrix-Matrix Multiplication
通用矩阵乘法(GEMM)详解
GEMM(General Matrix-Matrix Multiplication,通用矩阵乘法)是线性代数中的核心运算,定义为:
Where:
- (size ) and (size ) are input matrices
- (size ) is the input/output matrix
- and are scalar coefficients
- denotes matrix multiplication
- denotes element-wise addition
GEMM 的核心特点
基础性
- 是 BLAS(Basic Linear Algebra Subprograms) 中的 Level 3 运算,高性能计算(HPC)的基石。
- 深度学习中的关键操作(如全连接层、注意力机制、卷积的等效GEMM实现)。
高性能优化
- 在GPU(如NVIDIA)上通过以下方式优化:
- 分块计算(Tiling):将矩阵拆分为小块,适配寄存器/共享内存。
- Tensor Core加速(SM 7.0+):支持混合精度(FP16/BF16/INT8)。
- 典型库:CUTLASS、cuBLAS、OpenBLAS。
- 在GPU(如NVIDIA)上通过以下方式优化:
扩展变种
- Batched GEMM:批量处理多个矩阵(如Transformer中的多头注意力)。
- Strided GEMM:处理非连续内存布局(如矩阵的转置或子视图)。
GEMM 的计算过程
以 为例(假设 , ):
- 遍历 的行()和 的列()
- 对每个 ,计算 的第 行与 的第 列的点积:
代码示例(CUTLASS 实现)
#include <cutlass/gemm/device/gemm.h>
// 定义GEMM运算(float类型,行优先存储)
using Gemm = cutlass::gemm::device::Gemm<
float, cutlass::layout::RowMajor, // 矩阵A的布局
float, cutlass::layout::RowMajor, // 矩阵B的布局
float, cutlass::layout::RowMajor // 矩阵C的布局
>;
float alpha = 1.0f, beta = 0.0f;
Gemm gemm_op;
// 执行GEMM:C = alpha * A * B + beta * C
gemm_op({m, n, k}, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc);
THE END