Self-Attention with Trainable Weights
Why query, key, and value?
The terms key, query, and value in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.
A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.
The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.
The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
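To make the database analogy concrete, here is a minimal sketch that contrasts a hard dictionary lookup with the soft lookup that attention performs. The toy keys and values and the helper name `soft_lookup` are illustrative assumptions, not part of the original example.

```python
import math

def soft_lookup(query, keys, values):
    """Return softmax weights and the weighted average of `values`."""
    # Similarity score between the query and every key (dot product).
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Softmax turns the scores into positive weights that sum to 1.
    max_score = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Blend all values by their weights instead of picking exactly one.
    dim = len(values[0])
    blended = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return weights, blended

# Hard lookup: a dictionary returns exactly one value for an exact key match.
table = {"cat": [1.0, 0.0], "dog": [0.0, 1.0]}
hard_result = table["cat"]

# Soft lookup: every key matches the query to some degree.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0]]
weights, blended = soft_lookup([2.0, 0.0], keys, values)
print("Hard lookup:", hard_result)
print("Soft weights:", weights)
print("Soft result:", blended)
```

The query is most similar to the first key, so the first value dominates the result, but the second value still contributes a little. This is the sense in which attention "retrieves" values.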
import torch
################## x(2), the second token, as the query token ###################
### Attention mechanism
# Embeddings tensor: one 3-dimensional row per token
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your
[0.55, 0.87, 0.66], # journey
[0.57, 0.85, 0.64], # starts
[0.22, 0.58, 0.33], # with
[0.77, 0.25, 0.10], # one
[0.05, 0.80, 0.55]] # step
)
print(inputs.shape)
# Query token: the second token ("journey"), i.e. inputs[1] with zero-based indexing
x_2 = inputs[1]
# The input embedding size, d=3
d_in = inputs.shape[1]
# The output embedding size, d=2
d_out = 2
## initialize the weight matrices W(q), W(k), W(v)
torch.manual_seed(123) # For reproducibility
# Set requires_grad=False to reduce clutter in the printed outputs
W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
## Compute the query, key, and value vectors for x(2)
query_2 = torch.matmul(x_2, W_query) # Shape: (2,)
key_2 = torch.matmul(x_2, W_key) # Shape: (2,)
value_2 = torch.matmul(x_2, W_value) # Shape: (2,)
print("Query vector:", query_2)
keys = torch.matmul(inputs, W_key) # Shape: (6, 2)
values = torch.matmul(inputs, W_value) # Shape: (6, 2)
print("Keys shape:", keys.shape)
print("Values shape:", values.shape)
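# Compute the attention score omega(22): the query of token 2 against its own key
keys_2 = keys[1] # Get the key vector for the query token
attn_scores_22 = query_2.dot(keys_2) # Scalar (0-dim tensor)
print("Attention score:", attn_scores_22)
# Attention scores of the query against all six keys
attn_scores_2 = query_2 @ keys.T # Shape: (6,)
print(attn_scores_2)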
d_k = keys.shape[-1] # Dimension of the key vectors
# Scale by sqrt(d_k), then normalize over all six keys so the weights sum to 1
attn_weights_2 = torch.softmax(attn_scores_2 / (d_k ** 0.5), dim=-1) # Shape: (6,)
print("Attention weights:", attn_weights_2)
# Compute the context vector as the attention-weighted sum of all value vectors
context_vector_2 = attn_weights_2 @ values # Shape: (2,)
print("Context vector:", context_vector_2)
# So far, we’ve only computed a single context vector, z(2).
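The same steps extend to every token at once by stacking the queries into a matrix. The following is a minimal sketch under the same assumptions as the walkthrough (same inputs, `d_out = 2`, same seed). The names `queries`, `attn_scores`, `attn_weights`, and `context_vectors` are introduced here for illustration, and plain tensors stand in for `torch.nn.Parameter` to keep it short.

```python
import torch

torch.manual_seed(123)  # same seed as the walkthrough; any seed works

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
d_in, d_out = inputs.shape[1], 2

# Weight matrices, created in the same order as in the walkthrough
W_query = torch.randn(d_in, d_out)
W_key = torch.randn(d_in, d_out)
W_value = torch.randn(d_in, d_out)

# Project every token into query, key, and value space
queries = inputs @ W_query  # Shape: (6, 2)
keys = inputs @ W_key       # Shape: (6, 2)
values = inputs @ W_value   # Shape: (6, 2)

# Row i holds the scores of query i against all keys
attn_scores = queries @ keys.T  # Shape: (6, 6)

# Scale by sqrt(d_k) and normalize each row so it sums to 1
d_k = keys.shape[-1]
attn_weights = torch.softmax(attn_scores / d_k ** 0.5, dim=-1)

# Each context vector is a weighted sum of all value vectors
context_vectors = attn_weights @ values  # Shape: (6, 2)
print(context_vectors.shape)
print(context_vectors[1])  # context vector for the second token, z(2)
```

Row 1 of `context_vectors` corresponds to z(2). The remaining rows are the context vectors for the other five tokens, computed with the same weight matrices.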