- Published on
Single Head Self-Attention
- 代码 + 注释说明
import torch
import torch.nn as nn
# single head self-attention
# 1. 投影 (Projection): 使用三个独立的线性层 (W_query, W_key, W_value) 将输入序列 x 投影到查询(Query)、键(Key)和值(Value)三个不同的空间。
# 2. 计算注意力分数 (Attention Scores): 通过矩阵乘法 queries @ keys.transpose(-2, -1) 计算出每个查询与所有键之间的相似度。
# 3. 缩放 (Scaling): 将分数除以一个缩放因子 sqrt(d_k),以防止梯度在 softmax 后变得过小。
# 4. 归一化 (Normalization): 应用 softmax 函数将分数转换成概率分布(即注意力权重 attn_weights),其和为 1。
# 5. 加权求和 (Weighted Sum): 将注意力权重与 values 向量相乘,得到最终的上下文向量 context_vector。
# 自注意力的核心是计算 每个 Query 对所有 Key 的相似度
# 自注意力的核心思想是:对于序列中的每一个词(token),我们想计算它与序列中所有其他词(包括它自己)的关联程度或“注意力分数”。这个分数是通过计算代表当前词的 Query 向量与代表其他词的 Key 向量之间的点积得到的.
# Q的形状:[num_queries, d_k]
# K.T的形状:[d_k, num_keys]
# 结果 attention_scores:[num_queries, num_keys]
# d_in 是输入特征的维度(embedding size)
# d_out 是输出特征的维度(通常是更小的维度,用于压缩信息)
# qkv_bias 是一个布尔值,表示是否为查询、键和值的线性变换添加偏置项。
class SelfAttention(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
# 一个 nn.Linear 实例就是一个全连接层。
# 它们的作用是将同一个输入 x 投影到三个不同的、语义上有特定含义的子空间中.
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# Combine Q, K, V projections into a single linear layer for efficiency
# self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
# 简单类比
# 你可以把这个过程想象成去图书馆查资料:
# x 是你脑中模糊的“原始想法”。
# self.W_key 是一个“转换器”,它帮你把“原始想法”转换成图书馆能理解的**“关键词”**(keys)。
# 同样,你也会把“原始想法”转换成你要问的问题,即**“查询”**(queries)。
# 接下来,你就可以用你的“查询”去匹配书架上所有书的“关键词”,看看哪些书与你的问题最相关。
def forward(self, x):
# 它的作用是:通过一个可学习的线性变换(权重矩阵 W_key),将输入的词嵌入序列 x 转换(或投影)成 “键”(Key)向量序列。
keys = self.W_key(x) # Shape: (batch_size, seq_len, d_out)
# 含义: 代表了序列中每个词的“内容”或“信息”。
values = self.W_value(x) # Shape: (batch_size, seq_len, d_out)
queries = self.W_query(x) # Shape: (batch_size, seq_len, d_out)
# Split the combined tensor into Q, K, and V
# queries, keys, values = torch.chunk(qkv, 3, dim=-1)
# 矩阵乘法: 前一个张量的 最后一个维度 必须等于后一个张量的 倒数第二个维度 ,才能进行类似矩阵乘法的运算 。
attn_scores = queries @ keys.transpose(-2, -1) # More robust for batched inputs
print(f"Attention scores shape: ", attn_scores.shape)
attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
# context_vector 是对序列中所有词的 "Value" 向量进行加权求和的结果,而这个 "权重" 就是我们前面计算出的注意力权重 attn_weights。
# context_vector 的关键在于它为序列中的每个词生成了一个新的、带有上下文信息的嵌入(embedding)。
# 之前: 输入 x 中的每个词嵌入是孤立的,它只代表那个词本身。
# 之后: context_vector 中的每个词嵌入,是通过聚合序列中所有词的信息(values)而形成的。聚合的程度由注意力权重(attn_weights)决定。如果一个词与当前词高度相关,它的 value 就会在加权求和中占有更大的比重
context_vector = attn_weights @ values # Shape: (batch_size, seq_len, d_out)
return context_vector
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your
[0.55, 0.87, 0.66], # journey
[0.57, 0.85, 0.64], # starts
[0.22, 0.58, 0.33], # with
[0.77, 0.25, 0.10], # one
[0.05, 0.80, 0.55]] # step
)
# Add a batch dimension for a more realistic scenario
inputs = inputs.unsqueeze(0) # Shape: (1, 6, 3)
# The input embedding size, d=3
d_in = inputs.shape[-1] # Should be the last dimension (features)
# The output embedding size, d=2
d_out = 2
torch.manual_seed(789)
sa_v2 = SelfAttention(d_in, d_out)
output = sa_v2(inputs) # This calls the forward method
print("Output shape:", output.shape)
print("Output tensor:\n", output)
# Note:
# About sa_v2(inputs)
# 这个调用实际上会执行 SelfAttention 类中定义的 forward 方法。
# 这是一种在 PyTorch 中非常常见的用法。让我为你详细解释一下:
# sa_v2 是 SelfAttention 类的一个实例。
# SelfAttention 类继承自 torch.nn.Module。
# 在 Python 中,当一个类的实例被像函数一样调用时(例如 instance(argument)),Python 会去寻找并执行该类的 __call__ 特殊方法。
# torch.nn.Module 基类已经定义了 __call__ 方法。这个方法会做一些内部处理(比如注册钩子),然后它会调用你在自己类中定义的 forward 方法。
# 所以,完整的调用流程是: sa_v2(inputs) -> SelfAttention.__call__(inputs) (从 nn.Module 继承而来) -> SelfAttention.forward(inputs) (你自己定义的)。
# 直接调用 sa_v2.forward(inputs) 也能工作,但官方推荐使用 sa_v2(inputs) 的方式,因为这样可以确保所有在 __call__ 中定义的 PyTorch 内部钩子和机制都能被正确执行。
THE END