Published on

一个简单的Self-Attention机制的实现

自注意力机制(Self-Attention)是Transformer模型的核心组件之一。它允许模型在处理输入序列时,动态地关注序列中的不同部分,从而捕捉长距离依赖关系。

设计到的概念包括:

  • QueryKeyValue:每个输入token都会生成这三个向量。
  • Attention Scores:通过计算Query和Key之间的点积来获得。
  • Attention Weights:对Attention Scores进行Softmax归一化,
  • Context Vector:通过将Attention Weights应用于Value向量来生成。

下面是一个简单的自注意力机制实现示例,使用PyTorch来计算自注意力。

import torch
import torch.nn as nn

###################################### 计算自注意力 ######################################
# query token inputs[1] as Query token

### Attention mechanism
# Embeddings tensoer
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
    [0.55, 0.87, 0.66], # journey
    [0.57, 0.85, 0.64], # starts
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55]] # step
)
print(inputs.shape)

# Query token: Token index=1 as Query token inputs[1]
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(query, x_i)

# Attention Score: embedding elements [dot product] of the query token with all input tokens
print(attn_scores_2)

# This is a more efficient way to compute the dot-product attention scores
# for a single query against all inputs.
attn_scores_2 = torch.matmul(inputs, query)

# attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
# print("Attention weights:", attn_weights_2_tmp)
# print("Sum:", attn_weights_2_tmp.sum())

# Attention weights: Softmax normalization: Convert attention scores to Attention weights.
# 每个元素 σ(z) 是一个介于 0 和 1 之间的概率值。
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

# Context vector: Weighted sum of the input embeddings using the attention weights.
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i

# tensor([0.4419, 0.6515, 0.5683])
print(context_vec_2)

# The for-loop can be replaced by a more efficient matrix-vector multiplication.
context_vec_2 = torch.matmul(attn_weights_2, inputs)
print(context_vec_2)


###################################### 计算自注意力 ######################################
# Compute self-attention for all tokens in the input sequence

# All inputs
attn_scores = inputs @ inputs.T
print(attn_scores)

attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

THE END