Published on

PyTorch中Broadcasting(广播机制)

Broadcasting(广播机制) 是 PyTorch(和 NumPy)中一种隐式扩展张量维度的规则,它允许不同形状的张量进行逐元素操作(如加减乘除),而无需显式复制数据。其核心思想是:自动扩展较小张量的维度,使其与较大张量的形状兼容


1. 广播的基本规则

广播遵循两个关键规则:

  1. 从右向左对齐形状,逐维度比较:
    • 如果两个维度大小相同,或其中一个为 1,则兼容。
    • 如果其中一个维度缺失(即张量维度数不同),则在左侧补 1
  2. 在兼容的维度上,将大小为 1 的维度复制扩展为较大张量的对应维度大小。

2. 经典示例

示例 1:标量与张量相加

import torch

a = torch.tensor([1, 2, 3])  # shape: (3,)
b = 2                        # 标量(视为 shape: ())
c = a + b                    # 广播:b 被扩展为 [2, 2, 2]
print(c)                     # 输出: tensor([3, 4, 5])

示例 2:不同形状的矩阵操作

A = torch.tensor([[1, 2],    # shape: (2, 2)
                  [3, 4]])
B = torch.tensor([10, 20])   # shape: (2,)
C = A + B                    # 广播:B 被扩展为 [[10, 20], [10, 20]]
print(C)

输出:

tensor([[11, 22],
        [13, 24]])

3. 广播的实际步骤(以 A + B 为例)

  1. 对齐形状
    • A.shape = (2, 2)
    • B.shape = (2,) → 补 1 变为 (1, 2)
  2. 扩展维度
    • B 的第 0 维从 1 复制到 2,变为 (2, 2)
  3. 逐元素相加
    A: [[1, 2],   B(扩展后): [[10, 20],   Result: [[11, 22],
         [3, 4]]               [10, 20]]            [13, 24]]
    

4. 广播的常见用途

场景示例代码广播效果
归一化x - x.mean(dim=0)均值向量扩展匹配 x 的形状
权重乘法features * weights[None, :]插入 None 扩展维度
批量操作batch_data + bias偏置广播到所有样本

5. 广播的限制与错误

合法广播

A = torch.rand(3, 1, 4)  # (3, 1, 4)
B = torch.rand(   2, 1)   # (   2, 1)
C = A + B                 # 广播后: (3, 2, 4)

非法广播(报错)

A = torch.rand(3, 4)
B = torch.rand(2, 4)
C = A + B  # 报错:维度不匹配 (3,4) 和 (2,4) 的第0维既不相同也无1

6. 手动控制广播

通过 unsqueezereshape 显式扩展维度:

B = torch.tensor([10, 20])
B_expanded = B.unsqueeze(0)  # shape: (1, 2)
B_final = B_expanded.expand(2, 2)  # 显式复制为 (2, 2)

7. 性能注意事项

  • 优点:避免显式复制数据,节省内存。
  • 缺点:过度依赖广播可能导致计算图优化困难(尤其在动态图中)。

总结

广播机制通过隐式扩展张量维度,让不同形状的张量能直接运算。其核心是:

  1. 形状对齐(从右向左补 1)。
  2. 维度复制(大小为 1 的维度扩展)。
  3. 逐元素操作

THE END