nlp-with-transformers / notebooks

Jupyter notebooks for the Natural Language Processing with Transformers book
https://transformersbook.com/
Apache License 2.0

Chapter 5 - Text Generation | Beam Search Decoding - Log Probabilities #67

Open | gcmsrc opened this issue 2 years ago

Information

The question or comment is about chapter: 5 (Text Generation)
Question or comment

In the Beam Search Decoding section of Chapter 5, on page 132, the authors include the following function for calculating a sequence's log-probability:

def sequence_logprob(model, labels, input_len=0):
    with torch.no_grad():
        output = model(labels)
        log_probs = log_probs_from_logits(
            output.logits[:, :-1, :], labels[:, 1:])
        seq_log_prob = torch.sum(log_probs[:, input_len:])
    return seq_log_prob.cpu().numpy()

Here, labels corresponds to output_greedy, computed as:

max_length = 128
input_txt = """In a shocking finding, scientist discovered \
a herd of unicorns living in a remote, previously unexplored \
valley, in the Andes Mountains. Even more surprising to the \
researchers was the fact that the unicorns spoke perfect English.\n\n
"""
input_ids = tokenizer(input_txt, return_tensors="pt")["input_ids"].to(device)
output_greedy = model.generate(input_ids, max_length=max_length, 
                               do_sample=False)
print(tokenizer.decode(output_greedy.squeeze()))
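For reference, the sequence_logprob function above relies on log_probs_from_logits, which the book defines a bit earlier in the same chapter; paraphrasing from memory, it simply gathers the log-softmax probability assigned to each label token, roughly along these lines:

import torch
import torch.nn.functional as F

def log_probs_from_logits(logits, labels):
    # log-softmax over the vocabulary, then pick the probability assigned to each label token
    logp = F.log_softmax(logits, dim=-1)
    logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logp_label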

While the alignment between logits and labels is clear (i.e., the labels are shifted by one position), it is not clear to me why, when calculating the sequence log-probability, we slice the log_probs tensor starting at input_len instead of input_len - 1 (see below):

seq_log_prob = torch.sum(log_probs[:, input_len:])
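For context, input_len here is the number of prompt tokens; if I read the book correctly, the function is called with something like:

logp = sequence_logprob(model, output_greedy, input_len=len(input_ids[0]))

so the intent is to sum only over the log-probabilities of the model-generated tokens, not the prompt.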

Let me walk through the example in the book in detail:

When doing a forward pass with the model, i.e., output = model(output_greedy) (where output_greedy is what then gets passed as labels to the sequence_logprob function), the output includes logits whose dimension 1 has size 128 (our max_length). We know that the logits at index 0 actually refer to the second token in our output sequence. Python's 0-indexing is confusing here, but we can say that:

the logit at index 0 in output.logits corresponds to the logits for the second word in our sequence

In other words, logit index i corresponds to the (i + 2)-th word in the sequence (where the sequence is 1-indexed). Following the same reasoning, we know that the first truly model-generated token (i.e., the first token not present in the initial prompt) is the 48th word in the sequence, i.e., logit index 46 (48 - 2) in output.logits. We know that the model-generated text starts with The researchers, from the University of California, and we can verify this by running:

tokenizer.decode(torch.argmax(output.logits[0,46:52], dim=-1))

[screenshot: decoded output of the snippet above]

In other words, the delta between logit indices and word positions in the output sequence is 2 because of:

- Python's 0-indexing of the logits versus 1-indexed word counting (contributes 1), and
- the one-position shift between logits and labels, since each logit predicts the next token (contributes another 1).
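Here is a minimal, self-contained sketch of that alignment with made-up numbers (a hypothetical 5-token sequence with a 3-token prompt), just to illustrate where the first generated token's log-probability ends up:

import torch
import torch.nn.functional as F

input_len = 3                                    # hypothetical prompt length in tokens
labels = torch.tensor([[4, 7, 1, 8, 2]])         # t0..t2 = prompt, t3..t4 = generated
logits = torch.randn(1, 5, 10)                   # stand-in for output.logits (vocab size 10)

# logits[:, i] predicts labels[:, i + 1], exactly as in sequence_logprob above
logp = F.log_softmax(logits[:, :-1, :], dim=-1)
log_probs = torch.gather(logp, 2, labels[:, 1:].unsqueeze(2)).squeeze(-1)

# log_probs[:, j] is the log-probability of labels[:, j + 1], so the first
# generated token labels[:, input_len] sits at index input_len - 1:
print(log_probs[:, input_len - 1:])  # covers both generated tokens, t3 and t4
print(log_probs[:, input_len:])      # the book's slice: skips t3, keeps only t4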

As a consequence, the sequence_logprob function should be changed as follows:

def sequence_logprob(model, labels, input_len=0):
    with torch.no_grad():
        output = model(labels)
        log_probs = log_probs_from_logits(
            output.logits[:, :-1, :], labels[:, 1:])
        # CHANGE HERE: start at input_len - 1 so the log-probability of the
        # first generated token is included in the sum
        seq_log_prob = torch.sum(log_probs[:, input_len - 1:])
    return seq_log_prob.cpu().numpy()

Am I missing something?