lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0

Inference error on long-context text #3144

Open ZhangYuanhan-AI opened 5 months ago

ZhangYuanhan-AI commented 5 months ago

This is my demo code:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Initialize the tokenizer and model from the pretrained version on Hugging Face
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5-16k")
model = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5-16k")

# Prepare the text you want to infer on
text = "text" * 10000 
inputs = tokenizer(text, return_tensors="pt", max_length=16384, truncation=True)

# Generate output using the model
with torch.no_grad():
    outputs = model.generate(**inputs, max_length=16384, num_return_sequences=1)

# Decode the generated text
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(generated_text)

This is the error:

    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
RuntimeError: The size of tensor a (8192) must match the size of tensor b (10001) at non-singleton dimension 3

It seems that the maximum size of the causal mask is 8192: https://github.com/huggingface/transformers/blob/0290ec19c901adc0f1230ebdccad11c40af026f5/src/transformers/models/llama/modeling_llama.py#L1079
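
For reference, here is a minimal sketch for inspecting where the 16k context is supposed to come from; it only assumes the checkpoint's config carries max_position_embeddings and a rope_scaling entry, as Llama-family configs usually do:

from transformers import AutoConfig

# Load only the config; no weights are needed for this check
config = AutoConfig.from_pretrained("lmsys/vicuna-7b-v1.5-16k")

# Base position limit of the underlying Llama model, and the RoPE scaling
# settings that are meant to stretch it to 16k at inference time
print(config.max_position_embeddings)
print(getattr(config, "rope_scaling", None))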

env: transformers 4.38.2

freddyheppell commented 5 months ago

I've also been getting this, both when using transformers directly as you did and when going through FastChat.

This appears to be a bug in transformers: I get it on 4.38.1 and 4.38.2, but it works fine on v4.31.0 (the first version to add RoPE scaling). Weirdly, I can't find anything in the transformers changelog referring to RoPE until 4.38.2, yet it's also broken in 4.38.1.
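
For context, linear RoPE scaling roughly works by dividing the position indices by a scaling factor, so a long sequence is embedded as if it fit inside the base model's trained position range. A rough sketch of the idea (not the transformers implementation; the factor 4.0 is only an illustrative value):

import torch

def linearly_scaled_positions(seq_len: int, scaling_factor: float = 4.0) -> torch.Tensor:
    # e.g. a 16384-token sequence with factor 4.0 gets positions 0..4096,
    # the range the base 4k model was trained on
    return torch.arange(seq_len, dtype=torch.float32) / scaling_factor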

Gera001 commented 5 months ago

How can I solve it?

Gera001 commented 5 months ago

@Gera001 Try max_new_tokens=128; it's better to make it smaller.
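
For illustration, this is roughly what that change would look like in the snippet above (reusing the model and inputs defined there); note it only caps how many new tokens are generated, not the prompt length:

with torch.no_grad():
    # Limit only the number of newly generated tokens instead of the total length
    outputs = model.generate(**inputs, max_new_tokens=128, num_return_sequences=1)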

freddyheppell commented 5 months ago

You can solve this by downgrading the transformers library to 4.31.0 (a later version may also work, but 4.38.1 and 4.38.2 are definitely broken). Setting max_new_tokens=128 won't solve it; it will just take longer to break because it generates fewer tokens at once.
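
A small sketch of how you might guard against the broken releases before running long-context inference (the version bounds here are assumptions based on this thread, not an official compatibility range):

# pip install "transformers==4.31.0"
import transformers
from packaging import version

# Warn on the releases reported broken in this thread
if version.parse(transformers.__version__) >= version.parse("4.38.0"):
    print(f"transformers {transformers.__version__} may break long-context Vicuna inference; "
          "consider downgrading to 4.31.0")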