Published on

均值(Mean)​​和​​方差(Variance)

在深度学习和统计学中,**均值(Mean)方差(Variance)**是描述数据分布的核心指标。您代码中的操作是针对张量(Tensor)沿特定维度计算的统计量.


一、概念定义

术语数学定义物理意义
均值 (Mean)mean = sum(x_i) / n数据集中趋势
方差 (Variance)var = sum((x_i - mean)^2) / n数据波动性

二、代码行为解析

1. 输入张量 out 的假设

假设 out 是一个形状为 (batch_size, sequence_length, feature_dim) 的3维张量,例如:

out = torch.randn(2, 3, 4)  # 2个样本,3个时间步,4维特征

2. 计算过程

  • dim=-1:表示沿最后一个维度(即 feature_dim)计算,其他维度(batch_sizesequence_length)保持不变。
  • keepdim=True:保持输出张量的维度数不变(仅压缩计算维度为1)。

3. 输出结果示例

Mean:
 tensor([[[ 0.2], [-0.1], [ 0.4]],  
         [[-0.3], [ 0.0], [ 0.5]]])  # 形状 (2, 3, 1)

Variance:
 tensor([[[ 1.1], [ 0.8], [ 1.3]],  
         [[ 0.9], [ 1.2], [ 0.7]]])  # 形状 (2, 3, 1)

三、计算示意图

out.shape = (2, 3, 4) 为例:

graph TD
    A[out] -->|Shape: 2x3x4| B[Mean/var over dim=-1]
    B --> C[Output Shape: 2x3x1]
  • 实际计算:对每个 (batch, sequence) 位置的4维特征向量独立计算均值和方差。

四、为什么需要均值和方差?

1. 神经网络中的应用

  • 批归一化 (BatchNorm)
    通过标准化(减均值、除方差)加速训练:
    normalized_out = (out - mean) / torch.sqrt(var + eps)
    
  • 注意力机制:分析特征分布的稳定性(如Transformer中的梯度检查)。

2. 数据分析意义

  • 均值:若某时间步的均值接近0,说明特征以零为中心(适合ReLU等激活函数)。
  • 方差:若方差过小(如1e-6),可能引发梯度消失;若过大(如1e6),可能导致数值不稳定。

五、扩展验证

1. 手动计算验证

# 取 out[0,0,:] 的第一个特征向量
feature_vec = out[0, 0, :]  # 形状 (4,)
manual_mean = feature_vec.mean()  # 应与 mean[0,0,0] 一致
manual_var = feature_vec.var(unbiased=False)  # 应与 var[0,0,0] 一致

六、注意事项

  1. 数值稳定性
    方差计算时建议添加极小值 eps 防止除零:
    normalized_out = (out - mean) / torch.sqrt(var + 1e-5)
    
  2. 维度一致性
    若未设置 keepdim=True,输出形状会变为 (2, 3),可能导致后续广播错误。

总结:您代码中的 meanvar 是沿特征维度计算的统计量,常用于数据标准化或分析网络行为。理解它们的计算方式对调试模型和设计架构至关重要。

THE END