aws-neuron / transformers-neuronx

Apache License 2.0
NaN outputs when masking llama model inputs #79

Open dacorvo opened 4 months ago

dacorvo commented 4 months ago

In previous versions of transformers_neuronx, one could use start_ids to mask inputs during inference with Llama models.

Now, passing anything other than None or 0 as start_ids when calling the model's forward() method returns NaN scores.
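To make "masking inputs with start_ids" concrete, here is a minimal plain-Python sketch of the mask that left-padding offsets are usually meant to produce. It assumes (this is not taken from the transformers_neuronx source) that start_ids[b] is the index of the first real token of sequence b, so query i of sequence b may attend to key j only when start_ids[b] <= j <= i. The helper name left_padding_mask is hypothetical.

```python
from typing import List


def left_padding_mask(start_ids: List[int], seq_len: int) -> List[List[List[bool]]]:
    """Return mask[b][i][j] == True when query i of sequence b may attend to key j.

    Hypothetical helper (assumption, not the library's code): combines a causal
    mask (j <= i) with a left-padding mask (j >= start_ids[b]).
    """
    return [
        [[start <= j <= i for j in range(seq_len)] for i in range(seq_len)]
        for start in start_ids
    ]


# Sequence 0 is unpadded; sequence 1 is left-padded by one token.
mask = left_padding_mask([0, 1], seq_len=3)
for row in mask[1]:
    print(row)
```

For the padded sequence, query row 0 (the padding position itself) has no visible key at all, so how an implementation treats such rows matters for numerical stability.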

Here is an example script to illustrate the issue:

import argparse
import torch
from typing import List
from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling

def get_model(model_path):
    model = LlamaForSampling.from_pretrained(model_path,
                                             batch_size=2,
                                             amp='f16',
                                             tp_degree=2,
                                             n_positions=2048)
    model.to_neuron()
    return model

def get_padded_inputs(input_lengths: List[int], mask_inputs: bool = True):
    prompt = "It was a bright cold day in April, and the clocks were striking thirteen." \
             " Winston Smith, his chin nuzzled into his breast in an effort to escape the vile wind," \
             " slipped quickly through the glass doors of Victory Mansions, though not quickly enough" \
             " to prevent a swirl of gritty dust from entering along with him."
    t = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
    tokens = t(prompt)["input_ids"]
    input_ids = []
    max_length = max(input_lengths)
    for input_length in input_lengths:
        if input_length > len(tokens):
            raise ValueError(f"Input length should be at most {len(tokens)}")
        # Left-pad shorter prompts with EOS so that all rows have max_length tokens
        ids = [t.eos_token_id] * (max_length - input_length) + tokens[:input_length]
        input_ids.append(ids)
    input_ids = torch.tensor(input_ids)
    print(f"Using padded inputs of length {input_lengths}.")
    start_ids = None
    if mask_inputs:
        # Offset of the first non-padding token in each row
        start_ids = torch.argmax((input_ids != t.eos_token_id).to(torch.int64), dim=1)
        print(f"Using masked inputs, starting at offsets: {start_ids}")
    return {"input_ids": input_ids,
            "cache_ids": None,
            "start_ids": start_ids}

def greedy(model, inputs):
    scores = model(**inputs)
    return torch.argmax(scores, dim=-1)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", type=str)
    parser.add_argument("--input-length", type=int, default=64)
    parser.add_argument("--mask-inputs", action="store_true")
    args = parser.parse_args()
    input_lengths = (args.input_length, args.input_length - 1)
    inputs = get_padded_inputs(input_lengths, mask_inputs=args.mask_inputs)
    model = get_model(args.model_path)
    # Greedy: a single forward pass, then argmax over the vocabulary
    tokens = greedy(model, inputs)
    print(tokens)
    # NaN scores show up here as token id 0 (see the masked run below)
    assert torch.all(tokens != 0)
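The start_ids line in get_padded_inputs relies on torch.argmax returning the index of the first maximal value, which PyTorch documents. Below is a plain-Python sketch of the same computation so its edge cases are visible without a Neuron device. It assumes the Llama-2 EOS id of 2 as the pad token (as the script does via t.eos_token_id); the helper name first_non_pad_offsets is hypothetical.

```python
from typing import List


def first_non_pad_offsets(input_ids: List[List[int]], pad_id: int) -> List[int]:
    """Pure-Python equivalent of
    torch.argmax((input_ids != pad_id).to(torch.int64), dim=1).

    The argmax of a 0/1 row is the index of its first 1, i.e. the offset of
    the first non-pad token (hypothetical helper for illustration).
    """
    offsets = []
    for row in input_ids:
        flags = [int(tok != pad_id) for tok in row]
        offsets.append(flags.index(max(flags)))
    return offsets


EOS = 2  # assumed Llama-2 eos_token_id, used as the pad token in the script
print(first_non_pad_offsets([[1, 10, 11, 12], [EOS, 1, 10, 11]], EOS))  # [0, 1]
# An all-pad row also yields 0, so it is indistinguishable from an unpadded row.
print(first_non_pad_offsets([[EOS, EOS, EOS, EOS]], EOS))  # [0]
```

With --input-length 64, the script therefore passes start_ids equal to [0, 1]: only the second sequence is masked, by a single token.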

Assuming the script is saved as test_padded_inputs.py and a Llama checkpoint is saved under <llama-path>, you get the following results:

$ python test_padded_inputs.py <llama-path> --input-length 64
tensor([14653,   263])
$ python test_padded_inputs.py <llama-path> --input-length 64 --mask-inputs
tensor([0, 0])
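The script detects the NaNs only indirectly, through token id 0 in the greedy output. A more direct check is to inspect the scores themselves; with torch this would be torch.isnan(scores).any(dim=-1). The plain-Python sketch below does the same on nested lists, assuming scores are laid out as (batch, vocab), as the script's argmax(dim=-1) suggests. The helper name rows_with_nan is hypothetical.

```python
import math
from typing import List, Sequence


def rows_with_nan(scores: Sequence[Sequence[float]]) -> List[int]:
    """Return the batch indices whose score vector contains at least one NaN.

    Pure-Python stand-in for torch.isnan(scores).any(dim=-1), returned as
    a list of row indices instead of a boolean tensor.
    """
    return [b for b, row in enumerate(scores) if any(math.isnan(x) for x in row)]


nan = float("nan")
print(rows_with_nan([[0.1, 2.5, -1.0], [nan, nan, nan]]))  # [1]
```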

Is this expected? I noticed that the new continuous batching feature interprets start_ids as seq_ids: should start_ids now be used only to distinguish active from inactive sequences, and no longer to mask inputs?

If so, how are we supposed to handle inputs of different lengths in the same batch? Without masking, the outputs are complete gibberish.
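For context outside transformers_neuronx: a common way to batch prompts of different lengths is to left-pad them, mask the padding, and give the real tokens positions that restart at 0. Hugging Face generation for Llama computes position_ids = attention_mask.cumsum(-1) - 1 and fills padded slots with 1. The sketch below reproduces that convention in plain Python; whether transformers_neuronx expects or computes positions this way is not stated in this thread, and the helper name is hypothetical.

```python
from typing import List


def left_padded_position_ids(attention_mask: List[List[int]]) -> List[List[int]]:
    """Position ids that restart at 0 on the first real token of each row.

    Mirrors the Hugging Face convention (cumsum(mask) - 1, padded slots set
    to 1); hypothetical helper, not a transformers_neuronx API.
    """
    result = []
    for row in attention_mask:
        running, positions = 0, []
        for m in row:
            running += m
            positions.append(running - 1 if m else 1)
        result.append(positions)
    return result


print(left_padded_position_ids([[1, 1, 1, 1], [0, 1, 1, 1]]))
# [[0, 1, 2, 3], [1, 0, 1, 2]]
```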

dacorvo commented 4 months ago

Note that the NaN issue does not occur when masking if the inputs are short enough (fewer than 15 tokens). It does not occur with GPT-2 either.
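The thread does not identify the root cause. One generic way masked attention produces NaN (an assumption here, not confirmed for transformers_neuronx) is a query row in which every key is masked, such as a left-padding position. A max-subtracted softmax then computes -inf - (-inf), which is NaN. Once a NaN appears, an additive mask does not remove it, because NaN + (-inf) is still NaN, so it can reach every token in later layers. The plain-Python sketch below, with the hypothetical helper masked_softmax, shows both effects.

```python
import math
from typing import List

NEG_INF = float("-inf")


def masked_softmax(scores: List[float], visible: List[bool]) -> List[float]:
    """Max-subtracted softmax with an additive -inf mask and no guard for
    fully masked rows (hypothetical helper for illustration)."""
    logits = [s + (0.0 if v else NEG_INF) for s, v in zip(scores, visible)]
    m = max(logits)  # -inf when every position is masked
    exps = [math.exp(x - m) for x in logits]  # -inf - (-inf) is NaN
    total = sum(exps)
    return [e / total for e in exps]


# A row with at least one visible key behaves normally.
print(masked_softmax([1.0, 2.0, 3.0], [False, True, True]))
# A row with no visible key yields NaN everywhere.
print(masked_softmax([1.0, 2.0, 3.0], [False, False, False]))
# A NaN score at a masked key still poisons the whole row.
print(masked_softmax([float("nan"), 2.0, 3.0], [False, True, True]))
```

Implementations typically avoid this by applying the mask with a select rather than an addition, or by special-casing rows with no visible key.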

dacorvo commented 4 months ago

Just to give an idea of the consequences for inference.

Consider the input prompt "One of my fondest memory is of my grandmother making homemade bread".

Inference result with the same prompt twice in the batch, hence no padding:

'<s> One of my fondest memory is of my grandmother making homemade bread. It was a special occasion, like a birthday or holiday',
'<s> One of my fondest memory is of my grandmother making homemade bread. It was a special occasion, like a birthday or holiday'

Inference result with one of the prompts slightly longer (a "." appended), hence with one token of masked padding:

'</s><s> One of my fondest memory is of my grandmother making homemade bread<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>',
'<s> One of my fondest memory is of my grandmother making homemade bread.<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>'

If you do the same thing but omit the mask, the two sequences no longer produce the same continuation:

'</s><s> One of my fondest memory is of my grandmother making homemade bread. She would mix the dough by hand, kneading it', 
'<s> One of my fondest memory is of my grandmother making homemade bread. It was a special occasion, like a birthday or holiday'

If you do the same thing with even more padding, the padded sequence's output becomes complete gibberish:

'</s></s></s></s></s></s></s></s></s></s></s><s> One of my fondest memory is of my grandmother making homemade breadMSMSMS', 
'<s> One of my fondest memory is of my grandmother making homemade bread. It was a special occasion, like a birthday or holiday'

aws-rhsoln commented 4 months ago

Thank you for reporting the issue; we are trying to reproduce it on our end. Just to confirm: you are using transformers_neuronx from the 2.16 release?

dacorvo commented 4 months ago

Yes.

aws-rhsoln commented 4 months ago

We were able to reproduce the NaNs using the command: python test_padded_inputs.py <llama-path> --input-length 64 --mask-inputs. We are now looking into it.

dacorvo commented 2 months ago

Bump: more urgent now that continuous batching is also broken for Mistral and Mixtral.