99991 / SimpleTinyLlama

https://github.com/jzhang38/TinyLlama using only PyTorch
Apache License 2.0

Issue with attention mask unsqueeze in attention weights #4

Open meadewaking opened 8 months ago

meadewaking commented 8 months ago

While running batch inference, I found a bug in the attention weight computation: the attention mask is added to the attention weights after an unsqueeze on the wrong dimension. It should be unsqueezed along the second dimension (dim 1) rather than the first (dim 0). Here is the corrected line:

```
attn_weights = attn_weights + attention_mask[position_ids, :attn_weights.shape[3]].unsqueeze(1)
```
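For context, here is my reading of the shapes involved (the sizes below are made up for illustration, not taken from the repo): `attn_weights` is `(batch, n_heads, q_len, k_len)` while the indexed mask is `(batch, q_len, k_len)`, so the singleton dimension has to go in the head slot. With `unsqueeze(0)` the broadcast only lines up when the batch size is 1, which is why single-prompt inference still worked:

```
import torch

batch, n_heads, q_len, k_len = 2, 4, 1, 10   # example sizes with batch > 1

attn_weights = torch.zeros(batch, n_heads, q_len, k_len)
mask = torch.zeros(batch, q_len, k_len)      # stands in for attention_mask[position_ids, :k_len]

# Correct: unsqueeze(1) gives (batch, 1, q_len, k_len), broadcasting over the head dimension
print((attn_weights + mask.unsqueeze(1)).shape)   # torch.Size([2, 4, 1, 10])

# Wrong: unsqueeze(0) gives (1, batch, q_len, k_len); batch is broadcast against n_heads,
# which errors out here (and would silently mix batches whenever batch equals n_heads)
try:
    attn_weights + mask.unsqueeze(0)
except RuntimeError as e:
    print(e)
```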

Batch inference code:

```
import torch

# load_safetensors, ChatTokenizer, and llama are the helpers defined in this repository
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = None

model_filename = "data/model.safetensors"
tokenizer_filename = "data/tokenizer.model"

state_dict = load_safetensors(model_filename, device, dtype)
tokenizer = ChatTokenizer(tokenizer_filename)
prompts = ['import numpy as np', '#include<stdio.h>']

token_ids = [tokenizer.encode(i) for i in prompts]

cache = {}
max_text_len = 100
pad_id = -1

token_ids_torch = torch.full((len(token_ids), max_text_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(token_ids):
    token_ids_torch[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prompt_tokens_mask = token_ids_torch != pad_id

# Generate tokens one by one
for cur_pos in range(1, max_text_len):

    inputs = token_ids_torch[:, cur_pos-1:cur_pos]
    position_ids = torch.full(inputs.size(), cur_pos - 1, dtype=torch.long, device=device)

    # Predict logits
    logits = llama(inputs, position_ids, cache, state_dict)

    # Choose the most likely token
    new_token_ids = logits[:, -1].argmax(-1)

    # Stop once every sequence in the batch has produced the special end token
    if (new_token_ids == tokenizer.end_token_id).all():
        break

    # Append the new token to the list of tokens
    new_token_ids = torch.where(prompt_tokens_mask[:, cur_pos], token_ids_torch[:, cur_pos], new_token_ids)
    token_ids_torch[:, cur_pos] = new_token_ids

# Collect the generated ids once decoding has finished
token_ids = token_ids_torch.cpu().tolist()

response = [tokenizer.decode(i) for i in token_ids]
print(response)
```
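One step in the loop above worth calling out is the torch.where line: while a position is still covered by the prompt, the prompt token is kept and the model's prediction is discarded; only past the end of the prompt does the argmax result get written back. A self-contained toy example of just that masking step (tensors made up for illustration, not from the repo):

```
import torch

pad_id = -1
# Two prompts padded to length 5; the second one is shorter
token_ids_torch = torch.tensor([[5, 6, 7, pad_id, pad_id],
                                [9, pad_id, pad_id, pad_id, pad_id]])
prompt_tokens_mask = token_ids_torch != pad_id

cur_pos = 1
new_token_ids = torch.tensor([100, 200])  # pretend these came from logits.argmax(-1)

kept = torch.where(prompt_tokens_mask[:, cur_pos],   # still inside the prompt?
                   token_ids_torch[:, cur_pos],      # yes -> keep the prompt token
                   new_token_ids)                     # no  -> use the generated token
print(kept)  # tensor([  6, 200])
```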
99991 commented 8 months ago

Good catch! And thank you for the correction. I must have forgotten to test with batch size > 1 at some point.