huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Llama-13B gives nonsensical output past 1024 tokens #22433

Closed: michaelroyzen closed this issue 1 year ago

michaelroyzen commented 1 year ago

System Info

Latest transformers main branch, Python 3.10

Who can help?

@sgu

Reproduction

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

model_path = "/home/ubuntu/LLaMa-13B"
tokenizer_path = "/home/ubuntu/LLaMa-13B"

model = LlamaForCausalLM.from_pretrained(model_path).cuda()  # alternatively, pass device_map={"": 0} to from_pretrained
tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)

input_prompt = "some text that is over 1024 tokens"  # placeholder: any prompt longer than 1024 tokens

batch = tokenizer(input_prompt, return_tensors="pt", truncation=False)

with torch.no_grad():
    out = model.generate(
        input_ids=batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda(),
        max_new_tokens=100,
        do_sample=False,  # greedy decoding: top_k, top_p and temperature below are ignored
        top_k=50,
        top_p=1.0,
        temperature=1.0,
        use_cache=True
    )
print(tokenizer.decode(out[0]))

Expected behavior

Coherent text. The output is coherent when truncation=True. There shouldn't be an arbitrary limit at 1024 tokens, given that Llama was trained with a 2048-token sequence length and uses rotary embeddings, which should in theory support any sequence length.

sgugger commented 1 year ago

cc @ArthurZucker and @gante

gante commented 1 year ago

Hey @michaelroyzen 👋

Double-checking -- if you print model.config.max_sequence_length, do you get 2048? If not, overwriting it would be the first thing I'd do.

Secondly, there is this ongoing PR that may be related.

If model.config.max_sequence_length == 2048 and the PR above doesn't fix it, debugging becomes trickier, as the weights and configuration files are not public. In that case, can you try to reproduce the issue with GPTNeoX (e.g. with this model)?

michaelroyzen commented 1 year ago

Hi, @gante -- thanks for getting back to me! Unfortunately, I get AttributeError: 'LlamaConfig' object has no attribute 'max_sequence_length'. It doesn't seem like it's been implemented for Llama. I converted the official weights from FB using https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py, which doesn't seem to have created a 'max_sequence_length' property in the config.
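A defensive way to read the length limit without triggering the AttributeError above, sketched under the assumption that a config object may expose `max_position_embeddings` (the name mentioned later in this thread), the older `max_sequence_length`, or neither. `SimpleNamespace` objects stand in for real configs, and `context_length` is a hypothetical helper, not a transformers API.

```python
from types import SimpleNamespace


def context_length(config, default=None):
    """Return the first length attribute the config defines, else `default`.

    Hypothetical helper: checks `max_position_embeddings` first, then
    `max_sequence_length`, using getattr so a missing attribute does not
    raise AttributeError.
    """
    for name in ("max_position_embeddings", "max_sequence_length"):
        value = getattr(config, name, None)
        if value is not None:
            return value
    return default


# Stand-ins for configs that expose different attribute names.
llama_like = SimpleNamespace(max_position_embeddings=2048)
legacy_like = SimpleNamespace(max_sequence_length=2048)
bare = SimpleNamespace()

print(context_length(llama_like))   # 2048
print(context_length(legacy_like))  # 2048
print(context_length(bare, 1024))   # 1024
```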

gante commented 1 year ago

Hey @michaelroyzen! Two notes:

  1. There should not be an exception; I'll submit a PR to fix it. However, on closer inspection, that is not the problem: the maximum pre-initialized rotary position index is hardcoded to 2048, so it's okay (it should be config.max_sequence_length, but that doesn't change the issue here).
  2. I've attempted to reproduce it, and I noticed you set do_sample=False. Without sampling, the results are often underwhelming. I've run it locally with sampling, and everything looks fine. LMK if the example below works well on your checkpoint/input prompt :)

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

weights_path = "your-path-to-llama"

tokenizer = LlamaTokenizer.from_pretrained(weights_path, use_auth_token=True)
input_prompt = "The cat is"
batch = tokenizer(input_prompt, return_tensors="pt")
print(batch["input_ids"].shape)

model = LlamaForCausalLM.from_pretrained(weights_path, torch_dtype=torch.float16, use_auth_token=True).cuda()

# Manual left-padding with 1024 tokens
batch["input_ids"] = torch.cat([torch.ones((1, 1024), dtype=torch.long) * model.config.eos_token_id, batch["input_ids"]], dim=1)
batch["attention_mask"] = torch.cat([torch.zeros((1, 1024), dtype=torch.long), batch["attention_mask"]], dim=1)
print(batch["input_ids"].shape)

# Set seed for reproduction
torch.cuda.manual_seed(0)

with torch.no_grad():
    out = model.generate(
        input_ids=batch["input_ids"].cuda(),
        attention_mask=batch["attention_mask"].cuda(),
        max_new_tokens=100,
        do_sample=True,
        top_k=50,
        top_p=1.0,
        temperature=1.0,
        use_cache=True
    )
print(out[0].shape)
print(tokenizer.decode(out[0]))
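
The manual left-padding in the script above follows a general pattern: prepend pad ids to the input and matching zeros to the attention mask, so the real prompt stays right-aligned next to the tokens that generation appends. Below is a framework-free sketch with plain lists; `left_pad` is a hypothetical helper, and the token ids and pad id 2 are purely illustrative (the script reads its pad value from `model.config.eos_token_id`).

```python
def left_pad(input_ids, attention_mask, num_pad, pad_id):
    """Prepend `num_pad` padding positions to a single sequence.

    Padded positions get `pad_id` in the ids and 0 in the mask, so attention
    ignores them; the original tokens keep mask value 1.
    """
    padded_ids = [pad_id] * num_pad + list(input_ids)
    padded_mask = [0] * num_pad + list(attention_mask)
    return padded_ids, padded_mask


# Illustrative ids only, not a real tokenization of any prompt.
ids, mask = left_pad([1, 10, 11, 12], [1, 1, 1, 1], num_pad=3, pad_id=2)
print(ids)   # [2, 2, 2, 1, 10, 11, 12]
print(mask)  # [0, 0, 0, 1, 1, 1, 1]
```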

michaelroyzen commented 1 year ago

Thanks @gante. May I ask why it's hardcoded to 2048? This is not mentioned in the paper. And isn't the whole point of rotary embeddings to support infinite sequence length? Would it be possible to get a sequence length of 4096+ to work?

gante commented 1 year ago

@michaelroyzen Yes, rotary embeddings are, in practice, relative (and periodic!) position embeddings. See eq 12 in the original paper.
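A small numeric check of the relative property, as a sketch: rotary embeddings rotate each 2-D pair of query/key features by an angle proportional to the token position, so the attention score between positions m and n depends only on the offset m - n. This uses a single frequency pair with an arbitrary `theta`; in the standard formulation each pair gets its own frequency, but every pair behaves this way.

```python
import math


def rotate(vec, position, theta):
    """Rotate a 2-D vector by angle position * theta (one RoPE frequency pair)."""
    x, y = vec
    angle = position * theta
    c, s = math.cos(angle), math.sin(angle)
    return (x * c - y * s, x * s + y * c)


def rope_score(q, k, m, n, theta=0.01):
    """Dot product of q rotated to position m and k rotated to position n."""
    qm = rotate(q, m, theta)
    kn = rotate(k, n, theta)
    return qm[0] * kn[0] + qm[1] * kn[1]


q, k = (0.3, -1.2), (0.8, 0.5)
# Same offset m - n = 5 at very different absolute positions.
near = rope_score(q, k, m=10, n=5)
far = rope_score(q, k, m=4010, n=4005)
print(abs(near - far) < 1e-9)  # True
```

Because the rotation angles also wrap around every 2π, the encoding is periodic in each pair, which is the "periodic" part of the remark above.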

As you can see in our code, the hardcoded 2048 (now config.max_position_embeddings) is the initialization size -- they are immediately expanded upon request. They will never be the bottleneck 🙌
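The "initialize at 2048, expand on request" behavior can be sketched as a cos/sin table that is rebuilt larger whenever a longer sequence is requested. `RotaryCache` is a hypothetical pure-Python class illustrating the idea, not the transformers implementation; it assumes the standard RoPE frequencies `base ** (-2i / dim)` with `base=10000`.

```python
import math


class RotaryCache:
    """Hypothetical cos/sin table for rotary embeddings that grows on demand.

    `initial_len` plays the role of the pre-initialized size (2048 in this
    thread); requesting a longer sequence rebuilds a larger table instead of
    failing.
    """

    def __init__(self, dim, initial_len=2048, base=10000.0):
        # One inverse frequency per 2-D feature pair: base ** (-2i / dim).
        self.inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
        self._build(initial_len)

    def _build(self, length):
        self.cached_len = length
        self.cos = [[math.cos(p * f) for f in self.inv_freq] for p in range(length)]
        self.sin = [[math.sin(p * f) for f in self.inv_freq] for p in range(length)]

    def get(self, seq_len):
        # Rebuild only when the request exceeds what is already cached.
        if seq_len > self.cached_len:
            self._build(seq_len)
        return self.cos[:seq_len], self.sin[:seq_len]


cache = RotaryCache(dim=8, initial_len=2048)
cos, sin = cache.get(3000)
print(cache.cached_len, len(cos))  # 3000 3000
```

The table size is therefore never a hard limit; any quality drop past 2048 tokens would come from what the model saw during training, not from this cache.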

So... why 2048? Well, we'd have to ask the Llama creators, since they have hardcoded it in their repo 😅 It is not mentioned in the paper, as far as I can see, but I suspect they capped training at this sequence length. If my assumption is correct: while the model works beyond 2048 tokens (try changing 1024 to 2048 in the script I shared above), I would expect the quality to drop as we go beyond 2048 tokens, simply because of train-test skew :)

kechan commented 1 year ago

Just wondering whether there is still an issue if you set do_sample=True? I tried this on Apple silicon (MPS) and got decent results. The default sample code from Hugging Face is missing this argument, so I suspect the default is do_sample=False, and I was (wrongly) disappointed when I started getting either nonsense or repetitions.

michaelroyzen commented 1 year ago

Thank you @gante! max_position_embeddings is indeed the fix here. Closing this issue now.

gante commented 1 year ago


@kechan That's correct: do_sample=False is the default, and it hurts quality on this kind of task in particular (open-ended text generation) :) This blog post talks about it and explains why.
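To make the difference concrete, here is a framework-free sketch of a single decoding step, not the transformers `generate` implementation. `greedy_step` mirrors the idea behind do_sample=False (take the argmax every time, so the same context always yields the same token, which makes loops and repetition easy to fall into), while `top_k_sample_step` mirrors do_sample=True with top_k (draw from the k most likely tokens in proportion to their probabilities). The logits and helper names are illustrative assumptions.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with temperature scaling."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_step(logits):
    """do_sample=False analogue: always pick the highest-scoring token."""
    return max(range(len(logits)), key=lambda i: logits[i])


def top_k_sample_step(logits, k, rng, temperature=1.0):
    """do_sample=True + top_k analogue: sample among the k most likely tokens."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in top], temperature)
    return rng.choices(top, weights=probs, k=1)[0]


logits = [2.0, 1.9, 0.5, -1.0]
rng = random.Random(0)
print(greedy_step(logits))  # 0, on every call
draws = [top_k_sample_step(logits, k=2, rng=rng) for _ in range(10)]
print(draws)  # each draw is 0 or 1
```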