Greedy search decoding
The simplest decoding method to get discrete tokens from a model’s continuous output is to greedily select the token with the highest probability at each timestep.
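Formally, if y_1, …, y_{t-1} are the tokens produced so far, greedy decoding simply picks the mode of the next-token distribution at every step:

$$\hat{y}_t = \operatorname*{argmax}_{w \in V} P\left(w \mid y_1, \dots, y_{t-1}\right)$$

The model commits to that single choice and never reconsiders it, which is exactly what the loop below does by hand.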
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to(device)
print(f"model info: {model.config}")
print(f"model detailed info: {model}")

input_txt = "Transformer are the"
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)

iterations = []
n_step = 8
choices_per_step = 5

with torch.no_grad():
    for _ in range(n_step):
        iteration = dict()
        iteration["Input"] = tokenizer.decode(input_ids[0])
        output = model(input_ids=input_ids)
        # output.logits holds the unnormalized scores for every token,
        # with shape [batch_size, sequence_length, vocab_size].
        # The index [0, -1, :] selects:
        #   0  -> the first (and only) batch entry,
        #   -1 -> the last token of the current sequence (the position to predict next),
        #   :  -> the logits over the entire vocabulary.
        next_token_logits = output.logits[0, -1, :]
        next_token_probs = torch.softmax(next_token_logits, dim=-1)
        sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)
        # Record the top tokens and their probabilities for this step
        for choice_idx in range(choices_per_step):
            token_id = sorted_ids[choice_idx]
            token_prob = next_token_probs[token_id].cpu().numpy()
            token_choice = f"{tokenizer.decode(token_id)} ({100 * token_prob: .2f}%)"
            iteration[f"Choice {choice_idx+1}"] = token_choice
        # Greedy step: append the most likely token to the running input
        input_ids = torch.cat([input_ids, sorted_ids[None, 0, None]], dim=-1)
        iterations.append(iteration)

df = pd.DataFrame(iterations)
print(df)
# The same result with generate - make sure sampling is switched off!
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output = model.generate(input_ids, max_new_tokens=n_step, do_sample=False)
print(f"Output: {tokenizer.decode(output[0])}")
Output:
Input Choice 1 Choice 2 Choice 3 Choice 4 Choice 5
0 Transformer are the most ( 7.06%) only ( 3.43%) same ( 3.19%) first ( 1.94%) two ( 1.83%)
1 Transformer are the most common ( 28.70%) important ( 7.04%) popular ( 6.11%) powerful ( 5.56%) commonly ( 3.77%)
2 Transformer are the most common types ( 6.38%) type ( 5.11%) and ( 2.27%) components ( 1.37%) , ( 1.37%)
3 Transformer are the most common types of ( 73.70%) in ( 4.82%) . ( 3.56%) for ( 2.48%) , ( 2.01%)
4 Transformer are the most common types of transformer ( 2.17%) data ( 1.44%) trans ( 0.92%) transform ( 0.88%) objects ( 0.65%)
5 Transformer are the most common types of trans... . ( 28.06%) , ( 12.15%) in ( 7.66%) that ( 6.82%) used ( 6.69%)
6 Transformer are the most common types of trans... \n ( 21.92%) They ( 17.46%) The ( 6.09%) These ( 4.07%) In ( 2.66%)
7 Transformer are the most common types of trans... \n ( 99.86%) The ( 0.01%) A ( 0.01%) I ( 0.00%) In ( 0.00%)
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Output: Transformer are the most common types of transformer.
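The two warnings above come from calling generate without an attention mask or a pad token. Here is a minimal sketch of how to silence them, reusing the tokenizer, model, input_txt, and n_step defined above: pass the attention_mask the tokenizer already returns, and reuse the EOS token id as the pad token, since GPT-2 does not define one. The decoded text is unchanged because decoding is still greedy.

# Pass attention_mask and pad_token_id explicitly to avoid the warnings above
inputs = tokenizer(input_txt, return_tensors="pt").to(device)
output = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS
    max_new_tokens=n_step,
    do_sample=False,  # keep decoding greedy
)
print(f"Output: {tokenizer.decode(output[0])}")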
THE END