Published on

General Matrix-Matrix Multiplication

通用矩阵乘法(GEMM)详解

GEMMGeneral Matrix-Matrix Multiplication,通用矩阵乘法)是线性代数中的核心运算,定义为:

C=αA×B+βCC = \alpha \cdot A \times B + \beta \cdot C

Where:

  • AA (size m×km \times k) and BB (size k×nk \times n) are input matrices
  • CC (size m×nm \times n) is the input/output matrix
  • α\alpha and β\beta are scalar coefficients
  • ×\times denotes matrix multiplication
  • ++ denotes element-wise addition

GEMM 的核心特点

  1. 基础性

    • BLAS(Basic Linear Algebra Subprograms) 中的 Level 3 运算,高性能计算(HPC)的基石。
    • 深度学习中的关键操作(如全连接层、注意力机制、卷积的等效GEMM实现)。
  2. 高性能优化

    • 在GPU(如NVIDIA)上通过以下方式优化:
      • 分块计算(Tiling):将矩阵拆分为小块,适配寄存器/共享内存。
      • Tensor Core加速(SM 7.0+):支持混合精度(FP16/BF16/INT8)。
    • 典型库:CUTLASS、cuBLAS、OpenBLAS。
  3. 扩展变种

    • Batched GEMM:批量处理多个矩阵(如Transformer中的多头注意力)。
    • Strided GEMM:处理非连续内存布局(如矩阵的转置或子视图)。

GEMM 的计算过程

C=A×BC = A \times B 为例(假设 α=1\alpha=1, β=0\beta=0):

  1. 遍历 AA 的行(ii)和 BB 的列(jj
  2. 对每个 CijC_{ij},计算 AA 的第 ii 行与 BB 的第 jj 列的点积 Cij=l=1kAilBljC_{ij} = \sum_{l=1}^{k} A_{il} \cdot B_{lj}

代码示例(CUTLASS 实现)

#include <cutlass/gemm/device/gemm.h>

// 定义GEMM运算(float类型,行优先存储)
using Gemm = cutlass::gemm::device::Gemm<
    float, cutlass::layout::RowMajor,  // 矩阵A的布局
    float, cutlass::layout::RowMajor,  // 矩阵B的布局
    float, cutlass::layout::RowMajor   // 矩阵C的布局
>;

float alpha = 1.0f, beta = 0.0f;
Gemm gemm_op;

// 执行GEMM:C = alpha * A * B + beta * C
gemm_op({m, n, k}, alpha, d_A, lda, d_B, ldb, beta, d_C, ldc);

THE END