aws-neuron / transformers-neuronx


LLaMA fails when the input token length is over 1790 tokens #61

Closed dennj closed 4 months ago

dennj commented 7 months ago

I am trying to use meta-llama/Llama-2-13b-chat-hf, which has a max_position_embeddings of 4096 tokens. I found that the library fails in a non-deterministic way when the input length is between 1790 and 1800 tokens: if you submit exactly the same prompt several times, you randomly get either a good output or a failure. Above 1800 tokens the failure becomes essentially deterministic. However, LLaMA works fine with more than 2000 input tokens when run through the Hugging Face transformers library.

Here is a piece of code to reproduce the error. Model preparation:

import transformers

# Versions after 4.28.1 save the model in an incompatible format: https://github.com/aws-neuron/transformers-neuronx/issues/60
assert transformers.__version__ == "4.28.1", f"Version is {transformers.__version__}"

from transformers import LlamaForCausalLM
import torch
from transformers_neuronx.module import save_pretrained_split

model_name = "meta-llama/Llama-2-13b-chat-hf"
model = LlamaForCausalLM.from_pretrained(model_name)
save_pretrained_split(model, './Llama-2-13b-split')

# Compile the model

import time
import torch
from transformers import AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling
import torch_xla.core.xla_model as xm
import os

xla_device_count = len(xm.get_xla_supported_devices())

# load meta-llama/Llama-2-13b to the NeuronCores with N-way tensor parallelism and run compilation
neuron_model = LlamaForSampling.from_pretrained('./Llama-2-13b-split', batch_size=1, tp_degree=xla_device_count, amp='f16')
neuron_model.to_neuron()
neuron_model.save('./neuron_artifacts')
del neuron_model

Reproduce the bug:

# Load compiled model

import random
import time
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers_neuronx.llama.model import LlamaForSampling
import torch_xla.core.xla_model as xm

xla_device_count = len(xm.get_xla_supported_devices())

neuron_model = LlamaForSampling.from_pretrained('./Llama-2-13b-split', batch_size=1, tp_degree=xla_device_count, amp='f16')
neuron_model.load('./neuron_artifacts')
neuron_model.to_neuron()

model_name = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

for n_tokens in range(1780,2000):
    all_token_ids = list(tokenizer.get_vocab().values())
    random_token_ids = random.choices(all_token_ids, k=n_tokens)
    random_tokens_tensor = torch.tensor([random_token_ids])

    print(f'''Input with {len(random_tokens_tensor[0])} tokens
    Maximum sequence length for {model_name} is {config.max_position_embeddings} tokens''')

    max_output_length = config.max_position_embeddings - len(random_tokens_tensor[0])

    with torch.inference_mode():
        start = time.time()
        generated_sequences = neuron_model.sample(random_tokens_tensor, sequence_length=max_output_length, top_k=50)
        elapsed = time.time() - start

    generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
    print(f'generated sequences {generated_sequences} in {elapsed} seconds')

As I said, the bug is non-deterministic, so the code fails at a different iteration each time. Here is an example:

Input with 1783 tokens
    Maximum sequence length for meta-llama/Llama-2-13b-chat-hf is 4096 tokens
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
Cell In[96], line 21
     19 with torch.inference_mode():
     20     start = time.time()
---> 21     generated_sequences = neuron_model.sample(random_tokens_tensor, sequence_length=max_output_length, top_k=50)
     22     elapsed = time.time() - start
     26 generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/llama/model.py:174, in LlamaForSampling.sample(self, input_ids, sequence_length, start_ids, top_k, top_p, eos_token_override, temperature, streamer)
    171     context_length -= prefixed_length
    172     sequence_length -= prefixed_length
--> 174 result = sampling.sample_llama(
    175     self, input_ids, start_ids, sequence_length,
    176     eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override,
    177     top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer
    178 )
    180 return result

File ~/.local/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/sampling.py:288, in sample_llama(model, input_ids, start_ids, sequence_length, eos_token_id, top_k, top_p, temperature, streamer)
    286 _, start = input_ids.shape
    287 next_token_scores = model(input_ids, None, start_ids)
--> 288 return sample_loop_llama(
    289     model, input_ids, start_ids, next_token_scores, sequence_length, eos_token_id, top_k, top_p, temperature, streamer
    290 )

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/sampling.py:273, in sample_loop_llama(model, input_ids, start_ids, next_token_scores, sequence_length, eos_token_id, top_k, top_p, temperature, streamer)
    271     # forward pass to get next token
    272     cache_ids = torch.as_tensor([cur_len], dtype=torch.int32)
--> 273     next_token_scores = model(inputs, cache_ids, start_ids)
    275 if streamer:
    276     streamer.end()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/llama/model.py:158, in LlamaForSampling.forward(self, input_ids, cache_ids, start_ids)
    156 input_ids, *rst = self._preprocess(input_ids, start_ids=start_ids, cache_ids=cache_ids)  
    157 hidden = self.chkpt_model.model.embed_tokens(input_ids)
--> 158 return self._forward(hidden, *rst)

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/base.py:229, in NeuronModelBase._forward(self, hidden, *args)
    227     logits = self.context(hidden, *args)
    228 else:
--> 229     logits = self.decoder_lm_head(hidden, *args)
    231 logits = logits.to(torch.float32)
    232 logits = logits[:self.config.vocab_size, -1, :]

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/decoder.py:231, in DecoderLmHeadForSamplingNoEmbedding.forward(self, *inputs)
    229 sequence_length = hidden.shape[sequence_dim]
    230 if sequence_length == 1:
--> 231     return self.forward_single(*inputs)
    232 if sequence_length % self.n_active_tokens:
    233     raise ValueError(f'sequence_length={sequence_length} cannot be divided by '
    234                      f'n_active_tokens={self.n_active_tokens}')

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/decoder.py:216, in DecoderLmHeadForSamplingNoEmbedding.forward_single(self, *inputs)
    214 hidden, cache_ids, *_ = inputs
    215 batch_size = hidden.shape[2]
--> 216 bucket_id = self.program.find_bucket_id(cache_ids.item())
    217 if self.use_executor:
    218     return self.program.execute(bucket_id, batch_size, *inputs, return_ranks=self.return_ranks)

File ~/.local/lib/python3.10/site-packages/transformers_neuronx/decoder.py:1043, in DecoderProgram.find_bucket_id(self, length)
   1042 def find_bucket_id(self, length):
-> 1043     return next(idx for idx, npos in enumerate(self.n_positions_list) if npos >= length+1)

StopIteration: 
dennj commented 7 months ago

I was able to fix the problem by setting n_positions=4096 before compiling the model.

neuron_model = LlamaForSampling.from_pretrained('./Llama-2-13b-split', batch_size=1, tp_degree=xla_device_count, n_positions=4096, amp='f16')

Maybe this flag should be added to the tutorial to avoid other people having to deal with the same problem :)
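For reference, here is a minimal sketch of the full recompile step with the flag, reusing the paths and setup from the compilation snippet above:

# Recompile with n_positions=4096 so the compiled buckets cover the model's
# full 4096-token context (sketch; assumes the same split checkpoint and
# xla_device_count as in the compilation snippet above).
from transformers_neuronx.llama.model import LlamaForSampling
import torch_xla.core.xla_model as xm

xla_device_count = len(xm.get_xla_supported_devices())

neuron_model = LlamaForSampling.from_pretrained('./Llama-2-13b-split', batch_size=1,
                                                tp_degree=xla_device_count, n_positions=4096, amp='f16')
neuron_model.to_neuron()                  # compilation happens here, with the larger buckets
neuron_model.save('./neuron_artifacts')   # overwrite the old artifacts so the reload path picks up the change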

jyang-aws commented 7 months ago

Thanks @dennj for pointing out the issue. Yes, you can increase the number of positions by setting the n_positions argument. For example, to support up to 4k positions, you could do the following:

neuron_model = LlamaForSampling.from_pretrained('./Llama-2-13b-split', batch_size=1, tp_degree=xla_device_count, amp='f16', n_positions=4096)

We'll update the documentation accordingly.

dennj commented 7 months ago

Thanks :) Updating the documentation will be helpful.

However, the behaviour is non-deterministic. Is it OK for the library to behave randomly like this?

hannanjgaws commented 7 months ago

By default the tutorial uses top_k=50 sampling, which performs multinomial sampling. If you would like deterministic output, you can set top_k=1 in your sampling call. This will perform deterministic greedy sampling.
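For instance (a minimal sketch, assuming the compiled neuron_model and tokenizer from the snippets earlier in this issue; the prompt and sequence_length here are arbitrary illustration values):

import torch

# Greedy decoding: top_k=1 always selects the highest-probability token,
# so repeated runs on the same input produce the same output.
input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

with torch.inference_mode():
    generated = neuron_model.sample(input_ids, sequence_length=512, top_k=1)

print(tokenizer.decode(generated[0]))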

aws-donkrets commented 4 months ago

Hi dennj - okay to close this ticket or do you have any other questions?

mrnikwaws commented 4 months ago

Closing since there were no further comments