aws-neuron / aws-neuron-samples

Example code for AWS Neuron SDK developers building inference and training applications

meta llama2 13b sampling notebook example error with longer prompt #39

Closed by yapweiyih 9 months ago

yapweiyih commented 9 months ago

Model used: Llama-2-13b-chat-hf

Successfully ran the prompt in notebook example:

prompt = "Hello, I'm a language model,"

input_ids = tokenizer.encode(prompt, return_tensors="pt")

# run inference with top-k sampling
with torch.inference_mode():
    start = time.time()
    generated_sequences = neuron_model.sample(input_ids, sequence_length=2048, top_k=50)
    elapsed = time.time() - start

generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
print(f'generated sequences {generated_sequences} in {elapsed} seconds')

But it failed when I just replaced the prompt with a longer text (updating sequence_length to 4096 gave the same result):

LONG = """Summarize the text below:
---
EXTENDING CONTEXT WINDOW OF LARGE LAN-
GUAGE MODELS VIA POSITION INTERPOLATION

Shouyuan Chen Sherman Wong Liangjian Chen  Yuandong Tian
Meta Platforms Inc.
{chenshouyuan, shermanwong, cli, yuandong}@meta . com

1 INTRODUCTION

Large language models (LLMs) typically come with a pre-defined context window size. For exam-
ple, inputs to LLaMA models (Touvron et al., 2023) must be fewer than 2048 tokens. This pre-set
context window limit is frequently exceeded in applications such as conducting long conversations,
summarizing long documents, or executing long-term planning. For these applications, LLMs with
longer context windows are preferred. However, training an LLM from scratch with long context
windows requires significant investments. This naturally leads to a question: Can we extend the
context window of an existing pre-trained LLM?

One straightforward approach is to fine-tune an existing pre-trained Transformer with a longer con-
text window. However, empirically, we found that models trained this way adapt to long context
windows very slowly. After training for more than 10000 batches, the effective context window
saw a minimal increase, moving from 2048 to 2560 (Table 4). This suggests that such method is
inefficient for extending to substantially longer context windows.

While certain techniques such as ALiBi (Press et al., 2022) and LeX (Sun et al., 2022) enable length
extrapolation of Transformers, i.e. train on short context windows and inference on longer ones,
many existing pre-trained LLMs, including LLaMA (Touvron et al., 2023), use positional encodings
that have weak extrapolation properties (e.g., RoPE (Su et al., 2021)). Therefore, the applicability
of these techniques for extending the context window sizes of such LLMs remains limited.

In this work, we introduce Position Interpolation to enable context window extensions for certain
existing pre-trained LLMs, including LLaMA. The key idea is, instead of extrapolation, we directly
down-scale the position indices so that the maximum position index matches the previous context
window limit in the pre-training stage. See Figure 1 for an illustration. In other words, to accom-
modate more input tokens, we interpolate the position encodings at neighboring integer positions,
utilizing the fact that position encodings can be applied on non-integer positions, as opposed to
extrapolating outside the trained positions, which may lead to catastrophic values. We verify our
approach theoretically, by showing that the interpolated attention score has a much smaller upper

bound (~ 600x smaller in LLaMA 7B setting) than the extrapolated one, and is thus much more
stable. Therefore, interpolated position encodings are easier for the model to adapt.

Empirically, we found that Position Interpolation is highly effective and efficient, requiring only a
very short period of fine-tuning for the model to fully adapt to greatly extended context windows.
We present experimental results for extending the context window to up to 32768 from the initial
2048 across 7B to 65B LLaMA models using Position Interpolation. Our results show that

1. Position Interpolation can easily enable very long context windows (e.g. 32768), requiring
only fine-tuning for 1000 steps on the Pile (Gao et al., 2020) to achieve a good quality.
The cost of fine-tuning is negligible compared to the pre-training costs. This confirms
our hypothesis that it is relatively easy for the models to adapt to interpolated position
encodings.

2. Position Interpolation generates strong models that can effectively make use of much ex-
tended context window. We show that models extended by Position Interpolation enjoy
significant perplexity gains from greatly extended context windows for text modeling, and
we show that the perplexity reduces graceful with the enlargement of context windows.
We also applied Position Interpolation in a long text summarization task, and demonstrate
competitive performances.

3. Position Interpolation preserves model quality relatively well for tasks within its original
context window sizes. We present a variety of evaluation results for the extended LLaMA
models on the original LLaMA benchmark. Compared with original LLaMA models, the
extended LLLaM A models saw a minor degradation on several standard benchmarks within
a 2048 token limit.

Our results highlight the innate ability of Transformer models to “extrapolate to sequence lengths
longer than the ones encountered during training” as hypothesized in the seminal work of Vaswani
et al. (2017). We reaffirm this hypothesis and suggest that the previously known weakness of ex-
trapolating to longer sequences for language modeling (Press et al., 2022) may be due to direct

extrapolation of positional encodings and it can be largely mitigated by interpolating position en-
codings instead.

Concurrent work. Right before our release, we are informed with a concurrent blogpost (Super-
HOT kaiokendev (2023)) that also interpolates positional encoding in RoPE to extend the context
window from 2K to 8K. Recently, open source community picks it up in Reddit post ! and Github
Issues 2, which shows that fine-tuning with LoRA (Hu et al., 2021) also seems to work well. Our
paper shows a full fine-tuning with up to 65B model work well with Position Interpolation, and we
also give theoretical explanations why interpolation achieves much more stable results than extrap-
olation, by showing that the upper bound of interplated attention score is much lower than that of
extrapolated ones.

2 METHOD

2.1 BACKGROUND: ROTARY POSITION EMBEDDING (ROPE)

Transformer models require explicit positional information to be injected, typically in the form of
positional encodings, to represent the order of inputs. We consider Rotary Position Embedding
(ROPE) (Su et al., 2021), which is the position encoding used in the LLLaMA model (Touvron et al.,
2023). Given a position index m € [0, ¢) and an embedding vector x := [zg, 71,..., 241], Where
d is the dimension of the attention head, RoPE defines a vector-valued complex function f{x, m) as
follows

Using RoPE, the self-attention score
is only dependent on relative position m — 7 through trigonometric functions. Here q and k are the
query and key vector for a specific attention head. At each layer, RoPE is applied on both query and
key embeddings for computing attention scores.

2.2 DIRECT EXTRAPOLATION

While the attention score in RoPE only depends on the relative positions, which is what we want,
its extrapolation performance is not great . In particular, when directly extending to larger context
windows unseen in the training, the perplexity may shoot up to very high numbers (i.e., > 10%),
comparable to untrained models.

Ideally, we want to see the model trained on a context window of size L = 2048 to still work
reasonably well on longer context window, but may not have the capability to leverage information
that appears beyond L. For example, to answer a question located at 3000, the model trained on
maximal window size of I = 2048 cannot leverage evidences provided at location 0, but still
can leverage the evidences provided at location 2900. In contrast, in reality we see catastrophic
behaviors, i.e., question at location 3000 cannot be answered correctly, even if the evidences are
located at location 2900.

What is the reason behind? How could this happen if the attention score a,,,—,, decays as the relative
distance |m — n/| increases, according to Section 3.4.3 of (Su et al., 2021), and content from very
far distances should not matter that much? It turns out that the upper bound derived in Section 3.4.3
of (Su et al., 2021) may be too loose: while it indeed decays with respect to |m — nl, the bound
can still be quite large (i.e., the bound can be critically depends on the magnitude of v;) and thus
vacuous. In fact, if we treat all trigonometric functions as basis functions (i.e, ¢;(s) := #93), and
think about Eqn. 2 as basis expansion as the following:

where s is the positional span between a query and a key and h; := (ga; + igaj+1){k2j — tk2j+1)
are complex coefficients depending on q and k (here the definition of h; is exactly the same as the
definition of k; in Sec 3.4.3 in RoPE (Su et al., 2021)). Now the the issue becomes clear: as shown
in Fig. 2, a, can be small in magnitude in the range of [0, 2048], but gives huge values out of the
region. The underlying reason is that the trigonometric family {¢;} (with sufficiently large d) is
a universal approximator and can fit any arbitrary functions. Therefore, for a, there always exist
coefficients {h;} (i.e. key and query) that corresponds to small function values in [0, 2048] but

much larger in regions beyond.

---
"""

prompt = LONG
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# run inference with top-k sampling
with torch.inference_mode():
    start = time.time()
    generated_sequences = neuron_model.sample(input_ids, sequence_length=4096, top_k=50)
    elapsed = time.time() - start

generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
print(f'generated sequences {generated_sequences} in {elapsed} seconds')

Error log:

{
    "name": "StopIteration",
    "message": "",
    "stack": "---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
/home/ubuntu/efs_project/inf2/meta-llama-2-13b-sampling.ipynb Cell 18 line 8
      6 with torch.inference_mode():
      7     start = time.time()
----> 8     generated_sequences = neuron_model.sample(input_ids, sequence_length=2048, top_k=50)
      9     elapsed = time.time() - start
     11 generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/llama/model.py:210, in LlamaForSampling.sample(self, input_ids, sequence_length, start_ids, top_k, top_p, eos_token_override, temperature, streamer)
    207         # Sequence length cannot be greater than n_positions
    208         sequence_length = min(sequence_length, self.max_positions)
--> 210 result = sampling.sample_llama(
    211     self, input_ids, start_ids, sequence_length,
    212     eos_token_id=self.config.eos_token_id if eos_token_override is None else eos_token_override,
    213     top_k=top_k, top_p=top_p, temperature=temperature, streamer=streamer
    214 )
    216 if offset != 0:
    217     result = result[:, offset:]

File /opt/conda/envs/inf2/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/sampling.py:243, in sample_llama(model, input_ids, start_ids, sequence_length, eos_token_id, top_k, top_p, temperature, streamer)
    241 _, start = input_ids.shape
    242 cache_ids = torch.arange(start, dtype=torch.int32)
--> 243 next_token_scores = model(input_ids, cache_ids, start_ids)
    244 return sample_loop_llama(
    245     model, input_ids, start_ids, next_token_scores, sequence_length, eos_token_id, top_k, top_p, temperature, streamer
    246 )

File /opt/conda/envs/inf2/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/llama/model.py:179, in LlamaForSampling.forward(self, input_ids, cache_ids, start_ids)
    176 hidden = hidden.transpose(0, -1).contiguous()
    178 if context_length > 1:
--> 179     logits = self.context(hidden, cache_ids, start_ids)
    180 else:
    181     logits = self.decoder_lm_head(hidden, cache_ids, start_ids)

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/llama/model.py:163, in LlamaForSampling.context(self, hidden, cache_ids, start_ids)
    161     cache_ids = torch.as_tensor([i], dtype=torch.int32)
    162     hidden_slice = hidden[:, i:i+1].contiguous()
--> 163     logits = self.decoder_lm_head(hidden_slice, cache_ids, start_ids)
    165 return logits

File /opt/conda/envs/inf2/lib/python3.10/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/decoder.py:186, in DecoderLmHeadForSamplingNoEmbedding.forward(self, *inputs)
    184 sequence_length = hidden.shape[sequence_dim]
    185 if sequence_length == 1:
--> 186     return self.forward_single(*inputs)
    187 if sequence_length % self.n_active_tokens:
    188     raise ValueError(f'sequence_length={sequence_length} cannot be divided by '
    189                      f'n_active_tokens={self.n_active_tokens}')

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/decoder.py:173, in DecoderLmHeadForSamplingNoEmbedding.forward_single(self, *inputs)
    165 \"\"\"
    166 Fast-path forward function which avoids as much overhead as possible.
    167 
   (...)
    170 etc.
    171 \"\"\"
    172 _, cache_ids, *_ = inputs
--> 173 bucket_id = self.program.find_bucket_id(cache_ids.item())
    174 if self.use_executor:
    175     return self.program.execute(bucket_id, *inputs, return_ranks=self.return_ranks)

File /opt/conda/envs/inf2/lib/python3.10/site-packages/transformers_neuronx/decoder.py:903, in DecoderProgram.find_bucket_id(self, length)
    902 def find_bucket_id(self, length):
--> 903     return next(idx for idx, npos in enumerate(self.n_positions_list) if npos >= length)

StopIteration: "
}
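
For anyone reading the traceback: the last frame shows that find_bucket_id calls next() on a generator with no default value. That would explain why the error is a bare StopIteration with an empty message. A minimal sketch of that behaviour follows. It assumes the bucket list tops out at 2048, which is not shown in the log. The bucket values and the find_bucket_id_checked helper are hypothetical and not part of transformers_neuronx.

```python
def find_bucket_id(n_positions_list, length):
    # Same expression as the last frame of the traceback, written as a free function.
    return next(idx for idx, npos in enumerate(n_positions_list) if npos >= length)


def find_bucket_id_checked(n_positions_list, length):
    # Hypothetical variant: give next() a default so the failure is descriptive.
    idx = next((i for i, npos in enumerate(n_positions_list) if npos >= length), None)
    if idx is None:
        raise ValueError(
            f"length={length} exceeds the largest bucket ({max(n_positions_list)})"
        )
    return idx


# Assumed bucket sizes for illustration only; the real list depends on how the
# model was loaded.
buckets = [128, 256, 512, 1024, 2048]

short_bucket = find_bucket_id(buckets, 9)  # a short prompt fits the first bucket

try:
    find_bucket_id(buckets, 3000)  # position beyond every bucket
    overflow_error = None
except StopIteration as exc:
    overflow_error = type(exc).__name__

print(short_bucket, overflow_error)
```

If that reading is right, the prompt is simply longer than the largest sequence length the model was loaded with, and changing sequence_length in sample() cannot help, because line 208 of model.py in the traceback clamps it to self.max_positions.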
yapweiyih commented 9 months ago

I loaded the model with the following settings (n_positions=4096 and context_length_estimate=3500) and it works:

LlamaForSampling.from_pretrained(
    saved_dir,
    batch_size=1,
    tp_degree=24,
    n_positions=4096,
    context_length_estimate=3500,
    amp="f16",
)
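
A possible pre-flight check, in case it helps others hitting the same error: compare the prompt's token count with the n_positions value the model was loaded with before calling sample(). This is only a sketch. check_prompt_fits is a hypothetical helper, not a transformers_neuronx function, and the token count of 3000 is an assumed figure for the long prompt, not a measured one.

```python
def check_prompt_fits(num_prompt_tokens: int, n_positions: int, min_new_tokens: int = 1) -> int:
    """Return how many tokens can still be generated, or raise if the prompt is too long."""
    room = n_positions - num_prompt_tokens
    if room < min_new_tokens:
        raise ValueError(
            f"prompt has {num_prompt_tokens} tokens but the model was loaded with "
            f"n_positions={n_positions}; reload with a larger n_positions or shorten the prompt"
        )
    return room


# In the notebook the count would come from the tokenizer output,
# e.g. num_prompt_tokens = input_ids.shape[-1].
assumed_long_prompt_tokens = 3000  # assumption for illustration

room_with_4096 = check_prompt_fits(assumed_long_prompt_tokens, n_positions=4096)

try:
    check_prompt_fits(assumed_long_prompt_tokens, n_positions=2048)
    too_long_message = None
except ValueError as exc:
    too_long_message = str(exc)

print(room_with_4096)
print(too_long_message)
```

With the settings above, a prompt of that assumed size would leave about 1096 positions for generation, whereas a model loaded with n_positions=2048 would be rejected up front instead of failing inside find_bucket_id.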