- Published on
均值(Mean)和方差(Variance)
在深度学习和统计学中,**均值(Mean)和方差(Variance)**是描述数据分布的核心指标。您代码中的操作是针对张量(Tensor)沿特定维度计算的统计量.
一、概念定义
术语 | 数学定义 | 物理意义 |
---|---|---|
均值 (Mean) | mean = sum(x_i) / n | 数据集中趋势 |
方差 (Variance) | var = sum((x_i - mean)^2) / n | 数据波动性 |
二、代码行为解析
out
的假设
1. 输入张量 假设 out
是一个形状为 (batch_size, sequence_length, feature_dim)
的3维张量,例如:
out = torch.randn(2, 3, 4) # 2个样本,3个时间步,4维特征
2. 计算过程
dim=-1
:表示沿最后一个维度(即feature_dim
)计算,其他维度(batch_size
和sequence_length
)保持不变。keepdim=True
:保持输出张量的维度数不变(仅压缩计算维度为1)。
3. 输出结果示例
Mean:
tensor([[[ 0.2], [-0.1], [ 0.4]],
[[-0.3], [ 0.0], [ 0.5]]]) # 形状 (2, 3, 1)
Variance:
tensor([[[ 1.1], [ 0.8], [ 1.3]],
[[ 0.9], [ 1.2], [ 0.7]]]) # 形状 (2, 3, 1)
三、计算示意图
以 out.shape = (2, 3, 4)
为例:
graph TD
A[out] -->|Shape: 2x3x4| B[Mean/var over dim=-1]
B --> C[Output Shape: 2x3x1]
- 实际计算:对每个
(batch, sequence)
位置的4维特征向量独立计算均值和方差。
四、为什么需要均值和方差?
1. 神经网络中的应用
- 批归一化 (BatchNorm):
通过标准化(减均值、除方差)加速训练:normalized_out = (out - mean) / torch.sqrt(var + eps)
- 注意力机制:分析特征分布的稳定性(如Transformer中的梯度检查)。
2. 数据分析意义
- 均值:若某时间步的均值接近0,说明特征以零为中心(适合ReLU等激活函数)。
- 方差:若方差过小(如1e-6),可能引发梯度消失;若过大(如1e6),可能导致数值不稳定。
五、扩展验证
1. 手动计算验证
# 取 out[0,0,:] 的第一个特征向量
feature_vec = out[0, 0, :] # 形状 (4,)
manual_mean = feature_vec.mean() # 应与 mean[0,0,0] 一致
manual_var = feature_vec.var(unbiased=False) # 应与 var[0,0,0] 一致
六、注意事项
- 数值稳定性:
方差计算时建议添加极小值eps
防止除零:normalized_out = (out - mean) / torch.sqrt(var + 1e-5)
- 维度一致性:
若未设置keepdim=True
,输出形状会变为(2, 3)
,可能导致后续广播错误。
总结:您代码中的 mean
和 var
是沿特征维度计算的统计量,常用于数据标准化或分析网络行为。理解它们的计算方式对调试模型和设计架构至关重要。
THE END