huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Allow setting different decoder_start_token_ids for each item in a batch in the generate function. #28763

Open dpernes opened 8 months ago

dpernes commented 8 months ago

Feature request

@gante The generate function has a decoder_start_token_id argument that allows the specification of the decoder start token when generating from an encoder-decoder model (e.g. mT5). Currently, decoder_start_token_id must be an integer, which means that the same start token is used for all elements in the batch. I request that you allow the specification of different start tokens for each element of the batch. For this purpose, decoder_start_token_id must be a tensor with shape (batch_size,).

Motivation

Some multilingual encoder-decoder models use the decoder_start_token_id to indicate the target language. Thus, this change would allow generation into multiple target languages in parallel, as illustrated in the code below.

Your contribution

import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

WHITESPACE_HANDLER = lambda k: re.sub(r'\s+', ' ', re.sub(r'\n+', ' ', k.strip()))

article_text = """Videos that say approved vaccines are dangerous and cause autism, cancer or infertility are among those that will be taken down, the company said.  The policy includes the termination of accounts of anti-vaccine influencers.  Tech giants have been criticised for not doing more to counter false health information on their sites.  In July, US President Joe Biden said social media platforms were largely responsible for people's scepticism in getting vaccinated by spreading misinformation, and appealed for them to address the issue.  YouTube, which is owned by Google, said 130,000 videos were removed from its platform since last year, when it implemented a ban on content spreading misinformation about Covid vaccines.  In a blog post, the company said it had seen false claims about Covid jabs "spill over into misinformation about vaccines in general". The new policy covers long-approved vaccines, such as those against measles or hepatitis B.  "We're expanding our medical misinformation policies on YouTube with new guidelines on currently administered vaccines that are approved and confirmed to be safe and effective by local health authorities and the WHO," the post said, referring to the World Health Organization."""

model_name = "csebuetnlp/mT5_m2m_crossSum_enhanced"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

get_lang_id = lambda lang: tokenizer._convert_token_to_id(
    model.config.task_specific_params["langid_map"][lang][1]
)

target_langs = ["portuguese", "spanish"]

input_ids = tokenizer(
    [WHITESPACE_HANDLER(article_text)],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512
)["input_ids"]
input_ids = input_ids.expand(len(target_langs), -1)   # shape (num_target_languages, num_input_tokens)

decoder_start_token_id = torch.tensor(
    [get_lang_id(t) for t in target_langs],
    dtype=input_ids.dtype,
    device=input_ids.device
)  # shape (num_target_languages,)

output_ids = model.generate(
    input_ids=input_ids,
    decoder_start_token_id=decoder_start_token_id,
    max_length=84,
    no_repeat_ngram_size=2,
    num_beams=4,
)

summaries = tokenizer.batch_decode(
    output_ids,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

print(summaries)
gante commented 7 months ago

cc @zucchini-nlp

zucchini-nlp commented 7 months ago

@dpernes Hi, if you want to specify a different decoder_start_token_id for each element, you can do it by passing a tensor of shape (batch_size, seq_len). In your case, adding this line before generate is called will solve the issue:

decoder_start_token_id = decoder_start_token_id.unsqueeze(1) # shape (num_target_languages, 1)
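
With decoder_start_token_id reshaped to (num_target_languages, 1) as above, the original generate call can be used unchanged; a minimal sketch with the same variable names:

output_ids = model.generate(
    input_ids=input_ids,
    decoder_start_token_id=decoder_start_token_id,  # now one start token per batch element
    max_length=84,
    no_repeat_ngram_size=2,
    num_beams=4,
)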

dpernes commented 7 months ago

Great, thank you @zucchini-nlp! This behavior is not documented, though:

decoder_start_token_id (`int`, *optional*):
            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.

You may want to change it to something like:

decoder_start_token_id (`Union[int, torch.LongTensor]`, *optional*):
            If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token. Optionally, use a `torch.LongTensor` of shape `(batch_size, sequence_length)` to specify a prompt for the decoder.

But why isn't this the same as passing decoder_input_ids to generate? I tried passing the same tensor as decoder_input_ids instead of decoder_start_token_id and the results do not match.

zucchini-nlp commented 7 months ago

Thanks, I added a PR extending the docs.

Regarding your question, there is a subtle difference between them. decoder_start_token_id is used as the very first token in generation, the BOS token in most cases. decoder_input_ids, on the other hand, are used to start/continue the sentence from that point. In most cases you do not provide decoder_input_ids yourself when calling generate, so they are filled with decoder_start_token_id to start generation from BOS.

The general format is [decoder_start_token_id, decoder_input_ids], and generate automatically fills in decoder_start_token_id from the config if you do not provide it.
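
To make the difference concrete, here is a small illustrative sketch (the token ids below are made up; only the prefix logic matters):

import torch

# Made-up ids, purely for illustration.
config_start_id = 0   # what generate() takes from model.config.decoder_start_token_id
lang_id = 250042      # e.g. a language token you want the decoder to start from

# Passing decoder_start_token_id=lang_id: decoding begins from this prefix.
prefix_a = torch.tensor([[lang_id]])

# Passing decoder_input_ids=[[lang_id]] instead: the configured start token is
# still placed in front, so decoding begins from a different prefix.
prefix_b = torch.tensor([[config_start_id, lang_id]])

print(prefix_a.tolist(), prefix_b.tolist())  # different prefixes -> different generations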

tehranixyz commented 7 months ago

Hi, is there any way to specify decoder_start_token_id during training as well? Something like:

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
    decoder_start_token_id=decoder_start_token_id,
)
loss = outputs.loss

Each batch may require a different decoder_start_token_id during training, because each batch has a specific input language and output language, and the output language varies from batch to batch. Changing model.config.decoder_start_token_id for each batch doesn't seem like a good approach; in particular, it appears to cause a lot of inconsistency when using Accelerator with DeepSpeed.

zucchini-nlp commented 7 months ago

Hey @tehranixyz, you do not need to specify decoder_start_token_id while training. All you need to do is prepare the decoder_input_ids and pass them to the forward call. We use the start token from the model config only when we do not find decoder_input_ids from the user (see the code snippet for preparing decoder input ids from labels).
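
For reference, a minimal sketch of that labels-to-decoder_input_ids shifting, along the lines of what T5-style models do internally (the helper name here is hypothetical, and details may differ per model):

import torch

def shift_labels_right(labels: torch.Tensor, start_token_id: int, pad_token_id: int) -> torch.Tensor:
    # Prepend the start token, drop the last position, and replace
    # ignored label positions (-100) with the pad token.
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[:, 1:] = labels[:, :-1].clone()
    decoder_input_ids[:, 0] = start_token_id
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, pad_token_id)
    return decoder_input_ids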

tehranixyz commented 7 months ago

Gotcha! I was a bit confused by the warning `The decoder_input_ids are now created based on the "labels", no need to pass them yourself anymore.` when using EncoderDecoderModel. So in my case, I guess, as you said, I have to prepare decoder_input_ids myself by shifting the labels and adding the appropriate start token at the beginning. Many thanks!
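
For example, a training-side sketch along those lines, reusing the model and the langid_map lookup from the original snippet (toy inputs, intended only to illustrate the idea):

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "csebuetnlp/mT5_m2m_crossSum_enhanced"  # model from the original example
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Toy batch: one source sentence and one target sentence.
batch = tokenizer(["A short source sentence."], return_tensors="pt")
labels = tokenizer(["Uma frase alvo curta."], return_tensors="pt")["input_ids"]

# Per-batch start token, here the Portuguese language id (same lookup as above).
start_token_id = tokenizer._convert_token_to_id(
    model.config.task_specific_params["langid_map"]["portuguese"][1]
)

# Shift labels right and prepend this batch's start token.
decoder_input_ids = torch.cat(
    [labels.new_full((labels.size(0), 1), start_token_id), labels[:, :-1]], dim=1
)
decoder_input_ids = decoder_input_ids.masked_fill(decoder_input_ids == -100, tokenizer.pad_token_id)

outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=decoder_input_ids,
    labels=labels,
)
loss = outputs.loss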