Trainable Weight Self-Attention

Why query, key, and value?

The terms key, query, and value in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A query is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) that the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
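To make the analogy concrete, here is a small, purely illustrative sketch (the dictionary, keys, and values below are made up for this example and are separate from the attention code that follows): a Python dictionary performs a hard lookup that returns the value of exactly one matching key, whereas attention performs a soft lookup that returns a similarity-weighted mix of all values.

import torch

# Hard lookup: the query matches exactly one key and returns its value
database = {"Your": "v0", "journey": "v1", "starts": "v2"}
print(database["journey"])  # "v1"

# Soft lookup (the attention idea): compare the query against every key,
# then return a similarity-weighted mix of ALL the values
query = torch.tensor([0.55, 0.87, 0.66])
keys = torch.tensor([[0.43, 0.15, 0.89],
                     [0.55, 0.87, 0.66],
                     [0.05, 0.80, 0.55]])
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [0.5, 0.5]])
match = torch.softmax(keys @ query, dim=0)  # how strongly each key matches the query
print(match @ values)                       # weighted mix of the value vectors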

import torch

################## x^(2) as the query token ###################

### Attention mechanism
# Embedding tensor for the six input tokens
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your
    [0.55, 0.87, 0.66], # journey
    [0.57, 0.85, 0.64], # starts
    [0.22, 0.58, 0.33], # with
    [0.77, 0.25, 0.10], # one
    [0.05, 0.80, 0.55]] # step
)
print(inputs.shape)

# Query token: the second input token x^(2), i.e., inputs[1]
x_2 = inputs[1]

# The input embedding size, d_in=3
d_in = inputs.shape[1]

# The output embedding size, d_out=2
d_out = 2

## initialize the weight matrices W(q), W(k), W(v)
torch.manual_seed(123)  # For reproducibility
# requires_grad=False reduces clutter in the outputs; in a real model these
# weights would be trainable (requires_grad=True)
W_query = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

## Compute the query, key, and value vectors for x^(2)
query_2 = torch.matmul(x_2, W_query)  # Shape: (2,)
key_2 = torch.matmul(x_2, W_key)      # Shape: (2,)
value_2 = torch.matmul(x_2, W_value)  # Shape: (2,)
print("Query vector:", query_2)

keys = torch.matmul(inputs, W_key)  # Shape: (6, 2)
values = torch.matmul(inputs, W_value)  # Shape: (6, 2)
print("Keys shape:", keys.shape)
print("Values shape:", values.shape)

# Compute the unnormalized attention score ω_22 between query 2 and key 2
keys_2 = keys[1]  # Key vector for the second token
attn_scores_22 = query_2.dot(keys_2)  # Scalar (0-dim tensor)
print("Attention score ω_22:", attn_scores_22)

# Attention scores for query 2 against all six keys
attn_scores_2 = query_2 @ keys.T  # Shape: (6,)
print("Attention scores:", attn_scores_2)

# Normalize with softmax, scaling by sqrt(d_k) so the dot products don't grow
# with the key dimension and push softmax toward a near one-hot distribution
d_k = keys.shape[-1]  # Dimension of the key vectors, d_k=2
attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1)  # Shape: (6,)
print("Attention weights:", attn_weights_2)

# Compute the context vector z^(2): a weighted sum over *all* value vectors
context_vec_2 = attn_weights_2 @ values  # Shape: (2,)
print("Context vector:", context_vec_2)

# So far, we've only computed a single context vector, z^(2).
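
As a natural next step, the same operations can be applied to every token at once. Below is a minimal sketch of a compact self-attention module along those lines; the class name SelfAttentionV1 and its exact structure are illustrative choices (not taken from the code above), and it reuses d_in, d_out, and inputs defined earlier.

import torch.nn as nn

class SelfAttentionV1(nn.Module):
    """Minimal self-attention with trainable weight matrices (illustrative sketch)."""
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.randn(d_in, d_out))
        self.W_key = nn.Parameter(torch.randn(d_in, d_out))
        self.W_value = nn.Parameter(torch.randn(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query  # (num_tokens, d_out)
        keys = x @ self.W_key       # (num_tokens, d_out)
        values = x @ self.W_value   # (num_tokens, d_out)
        attn_scores = queries @ keys.T  # (num_tokens, num_tokens)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values    # one context vector per token

torch.manual_seed(123)
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))  # context vectors for all six tokens, shape (6, 2)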


THE END