- Published on
PyTorch中Broadcasting(广播机制)
Broadcasting(广播机制) 是 PyTorch(和 NumPy)中一种隐式扩展张量维度的规则,它允许不同形状的张量进行逐元素操作(如加减乘除),而无需显式复制数据。其核心思想是:自动扩展较小张量的维度,使其与较大张量的形状兼容。
1. 广播的基本规则
广播遵循两个关键规则:
- 从右向左对齐形状,逐维度比较:
- 如果两个维度大小相同,或其中一个为 1,则兼容。
- 如果其中一个维度缺失(即张量维度数不同),则在左侧补 1。
- 在兼容的维度上,将大小为 1 的维度复制扩展为较大张量的对应维度大小。
2. 经典示例
示例 1:标量与张量相加
import torch
a = torch.tensor([1, 2, 3]) # shape: (3,)
b = 2 # 标量(视为 shape: ())
c = a + b # 广播:b 被扩展为 [2, 2, 2]
print(c) # 输出: tensor([3, 4, 5])
示例 2:不同形状的矩阵操作
A = torch.tensor([[1, 2], # shape: (2, 2)
[3, 4]])
B = torch.tensor([10, 20]) # shape: (2,)
C = A + B # 广播:B 被扩展为 [[10, 20], [10, 20]]
print(C)
输出:
tensor([[11, 22],
[13, 24]])
A + B
为例)
3. 广播的实际步骤(以 - 对齐形状:
A.shape = (2, 2)
B.shape = (2,)
→ 补 1 变为(1, 2)
- 扩展维度:
B
的第 0 维从 1 复制到 2,变为(2, 2)
- 逐元素相加:
A: [[1, 2], B(扩展后): [[10, 20], Result: [[11, 22], [3, 4]] [10, 20]] [13, 24]]
4. 广播的常见用途
场景 | 示例代码 | 广播效果 |
---|---|---|
归一化 | x - x.mean(dim=0) | 均值向量扩展匹配 x 的形状 |
权重乘法 | features * weights[None, :] | 插入 None 扩展维度 |
批量操作 | batch_data + bias | 偏置广播到所有样本 |
5. 广播的限制与错误
合法广播
A = torch.rand(3, 1, 4) # (3, 1, 4)
B = torch.rand( 2, 1) # ( 2, 1)
C = A + B # 广播后: (3, 2, 4)
非法广播(报错)
A = torch.rand(3, 4)
B = torch.rand(2, 4)
C = A + B # 报错:维度不匹配 (3,4) 和 (2,4) 的第0维既不相同也无1
6. 手动控制广播
通过 unsqueeze
或 reshape
显式扩展维度:
B = torch.tensor([10, 20])
B_expanded = B.unsqueeze(0) # shape: (1, 2)
B_final = B_expanded.expand(2, 2) # 显式复制为 (2, 2)
7. 性能注意事项
- 优点:避免显式复制数据,节省内存。
- 缺点:过度依赖广播可能导致计算图优化困难(尤其在动态图中)。
总结
广播机制通过隐式扩展张量维度,让不同形状的张量能直接运算。其核心是:
- 形状对齐(从右向左补 1)。
- 维度复制(大小为 1 的维度扩展)。
- 逐元素操作。
THE END