A Simple Implementation of the Self-Attention Mechanism
The self-attention mechanism (Self-Attention) is one of the core components of the Transformer model. It allows the model to dynamically attend to different parts of the input sequence as it is processed, and thereby capture long-range dependencies.
The concepts involved include (a compact sketch of the general formula follows this list):
- Query, Key, and Value: each input token produces these three vectors.
- Attention Scores: obtained by computing the dot product between the Query and the Keys.
- Attention Weights: obtained by applying softmax normalization to the Attention Scores.
- Context Vector: produced by taking the weighted sum of the Value vectors using the Attention Weights.
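As a quick reference, the general (scaled) dot-product attention can be written in a few lines of PyTorch. This is only a sketch under the assumption that Q, K, and V are already-projected matrices of shape (seq_len, d_k); the simplified example below works on the raw embeddings and skips both the projections and the 1/sqrt(d_k) scaling.

import torch

def scaled_dot_product_attention(Q, K, V):
    # Q, K, V: (seq_len, d_k), assumed to be already-projected matrices
    d_k = K.shape[-1]
    scores = Q @ K.T / d_k ** 0.5            # pairwise similarities, scaled
    weights = torch.softmax(scores, dim=-1)  # each row sums to 1
    return weights @ V                       # weighted sum of the Value vectors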
Below is a simple self-attention implementation that uses PyTorch to compute the attention step by step.
import torch
import torch.nn as nn
###################### Simplified self-attention for a single query token ######################
# Embedding tensor: one 3-dimensional embedding per input token
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your
[0.55, 0.87, 0.66], # journey
[0.57, 0.85, 0.64], # starts
[0.22, 0.58, 0.33], # with
[0.77, 0.25, 0.10], # one
[0.05, 0.80, 0.55]] # step
)
print(inputs.shape)
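# -> torch.Size([6, 3]): 6 tokens, each represented by a 3-dimensional embedding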
# Query token: use the second token, inputs[1], as the query
query = inputs[1]
# Attention scores: dot product of the query with every input token's embedding
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(query, x_i)
print(attn_scores_2)
# This is a more efficient way to compute the dot-product attention scores
# for a single query against all inputs.
attn_scores_2 = torch.matmul(inputs, query)
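# Sanity check (added for illustration): the matmul result matches the explicit loop
scores_check = torch.stack([torch.dot(query, x_i) for x_i in inputs])
assert torch.allclose(attn_scores_2, scores_check)
# A naive alternative to softmax is to divide each score by the sum (kept below for comparison):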
# attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
# print("Attention weights:", attn_weights_2_tmp)
# print("Sum:", attn_weights_2_tmp.sum())
# Attention weights: softmax normalization converts the attention scores into attention weights.
# Each element σ(z) is a probability between 0 and 1, and the weights sum to 1.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())
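# For reference (added): a naive softmax gives the same weights here, but torch.softmax
# is numerically more stable when the scores have large magnitudes
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)
assert torch.allclose(attn_weights_2, softmax_naive(attn_scores_2))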
# Context vector: Weighted sum of the input embeddings using the attention weights.
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
print(context_vec_2)  # tensor([0.4419, 0.6515, 0.5683])
# The for-loop can be replaced by a more efficient matrix-vector multiplication.
context_vec_2 = torch.matmul(attn_weights_2, inputs)
print(context_vec_2)
######################### Self-attention for all tokens in the sequence ########################
# Compute attention scores, weights, and context vectors for every token at once
attn_scores = inputs @ inputs.T
print(attn_scores)
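# Note (added): without separate query/key projections the score matrix is symmetric,
# and its second row reproduces attn_scores_2 from the single-query example above
assert torch.allclose(attn_scores[1], attn_scores_2)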
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)
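# Each row of the weight matrix is one probability distribution and sums to 1
print(attn_weights.sum(dim=-1))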
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)
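# The second row matches the single-query context vector computed earlier
assert torch.allclose(all_context_vecs[1], context_vec_2)

The simplified version above uses the raw input embeddings directly as queries, keys, and values. In a real Transformer, self-attention additionally learns projection matrices for the Query, Key, and Value vectors and scales the scores by 1/sqrt(d_k). The sketch below is one illustrative way to package this as a PyTorch nn.Module (the class name SelfAttention and the sizes d_in, d_out are assumptions, not part of the original example); it reuses the imports and the inputs tensor defined above.

class SelfAttention(nn.Module):
    # Scaled dot-product self-attention with trainable Q/K/V projections (illustrative sketch)
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        # x: (seq_len, d_in)
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        scores = queries @ keys.T                                        # (seq_len, seq_len)
        weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)  # scaled softmax
        return weights @ values                                          # (seq_len, d_out)

torch.manual_seed(123)
sa = SelfAttention(d_in=3, d_out=2)
print(sa(inputs))  # one 2-dimensional context vector per input token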
THE END