System Info
transformers version: 4.31.0
Using distributed or parallel set-up in script?: No
Who can help?
@ArthurZucker @younesbelkada
Information
[ ] The official example scripts
[X] My own modified scripts
Tasks
[ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
[X] My own task or dataset (give details below)
Reproduction
I am using the FlashAttention v2 and BetterTransformer support that just came out in transformers v4.34. At the moment, when I enable FlashAttention using the code snippet below, I see that:
- The output length increases from ~1k tokens to ~3k tokens.
- All of the output becomes repetitive gibberish, whereas without FlashAttention v2 and BetterTransformer I get reasonable answers.
Why would FlashAttention cause this behaviour change? As I understand it, FlashAttention mainly improves the handling of longer contexts and makes computation more efficient, so it shouldn't affect model quality.
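For reference, the model loading is not reproduced in the snippet below; it looks roughly like this sketch (the checkpoint name and device placement are assumptions on my part — the relevant part is the use_flash_attention_2=True flag added in v4.34, and optionally the BetterTransformer conversion):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-chat-hf"  # assumed checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,       # FlashAttention v2 requires fp16/bf16
    use_flash_attention_2=True,      # flag introduced in transformers v4.34
    device_map="auto",
)
# BetterTransformer path (requires optimum):
# model = model.to_bettertransformer()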
# Test prompt
# from https://github.com/facebookresearch/llama-recipes/blob/main/examples/chat_completion/chats.json
prompt = [[
    {"role": "user", "content": "I am going to Paris, what should I see?"},
    {
        "role": "assistant",
        "content": "Paris, the capital of France, is known for its stunning architecture, art museums, historical landmarks, and romantic atmosphere. Here are some of the top attractions to see in Paris:1. The Eiffel Tower: The iconic Eiffel Tower is one of the most recognizable landmarks in the world and offers breathtaking views of the city. 2. The Louvre Museum: The Louvre is one of the world's largest and most famous museums, housing an impressive collection of art and artifacts, including the Mona Lisa. 3. Notre-Dame Cathedral: This beautiful cathedral is one of the most famous landmarks in Paris and is known for its Gothic architecture and stunning stained glass windows.These are just a few of the many attractions that Paris has to offer. With so much to see and do, it's no wonder that Paris is one of the most popular tourist destinations in the world.",
    },
    {"role": "user", "content": "What is so great about #1?"},
]]
import json
from typing import List, Literal, TypedDict

Role = Literal["user", "assistant"]

class Message(TypedDict):
    role: Role
    content: str

Dialog = List[Message]

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

def format_tokens(dialogs, tokenizer):
    prompt_tokens = []
    for dialog in dialogs:
        # Fold a leading system message into the first user turn
        if dialog[0]["role"] == "system":
            dialog = [
                {
                    "role": dialog[1]["role"],
                    "content": B_SYS
                    + dialog[0]["content"]
                    + E_SYS
                    + dialog[1]["content"],
                }
            ] + dialog[2:]
        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
            [msg["role"] == "assistant" for msg in dialog[1::2]]
        ), (
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with user and alternating (u/a/u/a/u...)"
        )
        # Please verify that your tokenizer supports adding "[INST]" and
        # "[/INST]" to your inputs; here we are adding them manually.
        dialog_tokens: List[int] = sum(
            [
                tokenizer.encode(
                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                )
                + [tokenizer.eos_token_id]
                for prompt, answer in zip(dialog[::2], dialog[1::2])
            ],
            [],
        )
        assert (
            dialog[-1]["role"] == "user"
        ), f"Last message must be from user, got {dialog[-1]['role']}"
        dialog_tokens += tokenizer.encode(
            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
        )
        prompt_tokens.append(dialog_tokens)
    return prompt_tokens
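Before calling generate, tokens is built from the prompt with format_tokens, roughly like this (a sketch following the llama-recipes chat_completion example; the device placement is an assumption):

import torch

chats = format_tokens(prompt, tokenizer)
tokens = torch.tensor(chats[0]).long().unsqueeze(0).to(model.device)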
import time
import numpy as np

num_responses = 32

# time the generate call for all responses
t0 = time.time()
output = model.generate(
    tokens.repeat([1, 1]),
    max_new_tokens=256,
    num_return_sequences=num_responses,
    do_sample=True,
    top_k=5,
    top_p=0.9,
    pad_token_id=tokenizer.eos_token_id,
)
t1 = time.time()
total = t1 - t0
print("generate total time:", np.round(total, 2), "secs")

print_result = True
if print_result:
    for i in range(len(output)):
        print("\n\nOutput:", i + 1, "---------------------------------------------------------\n")
        output_text = tokenizer.decode(output[i], skip_special_tokens=True)
        print(f"length of output: {len(output_text)}")
        if i == 0:
            # print the prompt once, before the first response
            print(output_text.split('[/INST]')[0])
            print('\n\n')
        print(output_text.split('[/INST]')[-1])
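For the baseline run without FlashAttention (the case that gives reasonable answers), the generation code is unchanged; only the model load differs, roughly:

# Baseline sketch: same generation code, default attention implementation
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
)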
Expected behavior
Expected output:
Output: 1 ---------------------------------------------------------
length of output: 1932
[INST] I am going to Paris, what should I see?
The Eiffel Tower is considered one of the most iconic landmarks in Paris and one of the most recognizable symbols of France. Here are some reasons why it's so great:
1. Unique Design: The Eiffel Tower is an engineering marvel with its lattice-style design, which was revolutionary for its time. It was designed by Gustave Eiffel and his company for the 1889 World's Fair, held in Paris.
2. Breathtaking Views: The Eiffel Tower offers panoramic views of the city from its observation decks on the first and second levels. Visitors can see many of Paris's famous landmarks, such as the Arc de Triomphe, the Champs-Élysées, and the Seine River.
3. Historical Significance: The Eiffel Tower was a symbol of French engineering and innovation during the late 19th century. It was also a symbol of French culture and art, as it was the tallest structure in the world at the time of its construction.
4. Romantic Atmosphere: The Eiffel Tower is often associated with romance, thanks to its