huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Truncated assistant message gets a 0 assistant mask #34494

Open Butanium opened 2 days ago

Butanium commented 2 days ago

System Info

Who can help?

@yonigottesman

Information

Tasks

Reproduction

I modified the Gemma template to allow assistant_masks to work:

{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{{ '<start_of_turn>' + role + '\n'}}{% generation %}{{message['content'] | trim}}{% endgeneration %}{{ '<end_of_turn>\n' }}{% else %}{% set role = message['role'] %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}

However, if a model message gets truncated, the mask is all zeros:

from transformers import AutoTokenizer
from IPython.display import display, HTML

# the modified Gemma chat template shown above
better_template = "<copy above>"
chat = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing great, thank you!"},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "The capital of France is Paris."},
]
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
tokens = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
    max_length=20,  # short enough to cut off mid-way through an assistant turn
    truncation=True,
)
# highlight tokens whose assistant mask is 1 in red
highlighted_tokens = [
    (
        f"<span style='color: red; border: 1px solid red; padding: 2px;'>{token.replace('<', '&lt;').replace('>', '&gt;')}</span>"
        if mask
        else token.replace("<", "&lt;").replace(">", "&gt;")
    )
    for token, mask in zip(
        tokenizer.convert_ids_to_tokens(tokens["input_ids"]), tokens["assistant_masks"]
    )
]

md = "".join(highlighted_tokens)
from IPython.display import display, HTML

display(HTML(md))

# %%
tokens["assistant_masks"]
# %%
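A display-free check makes the symptom concrete (a minimal sketch; the 0 is the reported buggy behavior, not the expected one):

# %%
# The mask is a plain list of 0/1 ints; with the bug it contains no 1s
# at all for this truncated input.
print(sum(tokens["assistant_masks"]))  # 0 with the bug; > 0 once fixed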

Expected behavior

I'd expect the mask to have 1s on the tokens of the partial model response.
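Until that happens, a possible workaround (a sketch, assuming the default right-side truncation) is to render the full conversation first and then slice the ids and the mask together:

# Workaround sketch (assumes truncation_side="right", the default):
# compute the assistant mask on the untruncated sequence, then cut
# input_ids and assistant_masks to the same length by hand.
full = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
)
max_length = 20
input_ids = full["input_ids"][:max_length]
assistant_masks = full["assistant_masks"][:max_length]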

Butanium commented 2 days ago

Also @yonigottesman, assistant_masks is not converted to a tensor even if I do

tokens = tokenizer.apply_chat_template(
    chat,
    tokenize=True,
    return_assistant_tokens_mask=True,
    return_dict=True,
    chat_template=better_template,
    return_tensors="pt",
)
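In the meantime, converting it by hand works (a sketch; depending on the version you may need to reshape it to match the batched input_ids):

import torch

# Workaround sketch: input_ids comes back as a "pt" tensor here, but
# assistant_masks is still a plain Python list, so convert it manually.
assistant_masks = torch.tensor(tokens["assistant_masks"])
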
yonigottesman commented 13 hours ago

@Butanium you are right, there is a bug in my code; I will fix and update. BTW, you should include the {{ '<end_of_turn>\n' }} inside the generation block, as you want the model to learn to output this string when it's done.

Butanium commented 8 hours ago

Thank you! I edited my template to include the <end_of_turn> but not the \n, as those are different tokens.
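
For reference, the assistant branch of the edited template would then presumably read (with <end_of_turn> inside the generation block and the newline left outside):

{% generation %}{{message['content'] | trim}}{{ '<end_of_turn>' }}{% endgeneration %}{{ '\n' }}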