huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Finetuning Whisper with prompts #24272

Open AvivSham opened 1 year ago

AvivSham commented 1 year ago

Feature request

Training code implementation for finetuning Whisper using prompts.

Hi All, I’m trying to finetune Whisper by resuming its pre-training task and adding initial prompts as part of the model’s forward pass. I saw this amazing tutorial; however, it does not contain a section about using prompts as part of the fine-tuning dataset.

Motivation

We have observed that Whisper does not behave as expected when transcribing with prompts. Sometimes the output is blank, and on other occasions the output text contains repetitions. We want to fix this behavior by fine-tuning Whisper with prompts.

Your contribution

Open for ideas.

amyeroberts commented 1 year ago

Hi, thanks for raising an issue!

Questions about custom training and model performance are best placed in our forums. We try to reserve the github issues for feature requests and bug reports.

If you believe the behaviour is due to a bug in the model, then please share all the necessary information so that we can reproduce the issue on our side: running environment; minimal code snippet. And full details about the observed behaviour e.g. example outputs and the expected behaviour.

AvivSham commented 1 year ago

Hi @amyeroberts, thank you for your fast response. I have already opened a thread in the forum. I agree that this is not a direct bug, but the current behavior of Whisper does not make any sense either (blank transcriptions + repetitions). How should I proceed from here?

Thanks!

amyeroberts commented 1 year ago

@AvivSham You should wait to see if anyone replies to your post in the forum. I'd also suggest checking out the discord, as it's active and there's lots of people sharing ideas and helping one another with projects.

AvivSham commented 1 year ago

How can I enter the discord server? Can you please share URL / QRcode? I tried the following link but it seems to be invalid.
BTW this issue can be marked as a feature request since (as of now) I did not see any relevant code for fine-tuning Whisper with prompts.

amyeroberts commented 1 year ago

Hi @AvivSham - the discord link you shared (here), is the same one I would use, and works for me. What do you mean by it being 'invalid'?

AvivSham commented 1 year ago

@amyeroberts Please see the attached image.

[image]

amyeroberts commented 1 year ago

@AvivSham Oh no :/ I've tried making a new account with the previous link and it worked, so I'm not sure what's going on unfortunately. I'll see on our end if there's any known issues / resolutions.

In the meantime, do you already have a discord account or are you able to make one independent of this server invite?

AvivSham commented 1 year ago

@amyeroberts Hi Amye, I re-opened this issue since I did not get any support over discord and the HF forum. I think that this issue is a high priority for DL practitioners! Can you please help with it?

amyeroberts commented 1 year ago

Hi @AvivSham,

So that we can figure out whether this is an issue on our side, could you confirm that you have an active discord account or are able to create one (independent of the HF invite link)?

AvivSham commented 1 year ago

Hi @amyeroberts I feel like you are totally ignoring my questions :/ See my latest message, please.

amyeroberts commented 1 year ago

Hi @AvivSham,

I'm certainly not ignoring the questions. Please understand that we're all very busy and trying to address as many issues as possible. As the previous thread was discussing technical difficulties in joining the discord server, I'd understood that this was the ongoing issue; my apologies for misunderstanding.

With regards to training Whisper with prompts, the same applies as in my first comment. Unless there's a specific behaviour which you believe to be a bug in the model, this is a question for the forums / discord and not github issues. Not having responses isn't justification for posting in github issues, as it just isn't a scalable solution.

Lauler commented 1 year ago

Sorry for reviving this thread. I was going to create my own issue (but saw this one already existed). I do actually think this is a legitimate feature request based off of discussions in a pull request that is related to this issue. The original post is however not worded in the best manner to explain what is requested and to demonstrate the general benefit.

The relevant PR (https://github.com/huggingface/transformers/pull/22496) added prompt support for Whisper inference. In the PR a user asked whether similar support could be added for finetuning. @hollance and @sanchit-gandhi replied with ideas of how prompting support during training could be implemented and a suggestion to start a new issue (https://github.com/huggingface/transformers/pull/22496#issuecomment-1557501336, https://github.com/huggingface/transformers/pull/22496#issuecomment-1556934882) for the feature.

My alternative wording of this feature request:

Feature request

Huggingface recently added support for prompting Whisper with model.generate() (see https://github.com/huggingface/transformers/issues/22395, https://github.com/huggingface/transformers/pull/22496). In the PR, there were discussions (https://github.com/huggingface/transformers/pull/22496#issuecomment-1557501336) of adding similar support for including parts of the previous (text) context when training and finetuning the model. It was suggested a new issue be started for the feature request, though no one ended up creating the issue.

The Whisper paper seems to suggest the general pretraining process was:

Relevant parts of the paper:

Since our decoder is an audio-conditional language model, we also train it to condition on the history of text of the transcript in the hope that it will learn to use longer-range text context to resolve ambiguous audio. Specifically, with some probability we add the transcript text preceding the current audio segment to the decoder’s context. [...] When a final transcript segment is only partially included in the current 30 second audio chunk, we predict only its start time token for the segment when in timestamp mode, to indicate that the subsequent decoding should be performed on an audio window aligned with that time, otherwise we truncate the audio to not include the segment.

Support for prompting in training/finetuning has also been requested and discussed on the HF forums:

https://discuss.huggingface.co/t/adding-prompt-context-to-whisper-with-huggingface-transformers/31070 https://discuss.huggingface.co/t/finetuning-whisper-with-prompts/43053

I believe being able to include previous context in finetuning would be a useful feature. It would also enable users to finetune the model in a manner that is consistent with how it was pretrained (i.e. how the final segment is handled when it is only partially included in the audio). This is something that may have an effect on the robustness of finetuned models when it comes to long form transcription and timestamps.

The reason OpenAI preprocessed data in this manner during training is likely that it best mimics the kind of data the model would see during inference (i.e. audio being chunked so that it regularly cuts off in the middle of sentences and/or words).

amyeroberts commented 1 year ago

@Lauler OK, I see. Thanks for taking the time to write up such a clear explanation and to link to all the relevant issues, PR and discussions.

As this is a feature request I'll re-open and tag it as such :)

cc @sanchit-gandhi

sanchit-gandhi commented 1 year ago

Hey @AvivSham and @Lauler, really cool to see such excitement around developing Whisper fine-tuning further! Thanks both for the motivations for the feature request.

In terms of what we have to do to make the fine-tuning script work with prompted fine-tuning, it's super simple. All we have to do is update the prepare_dataset function to encode the prompts, the target text, and then combine them to get the labels:

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode prompts to prompt ids - we assume that the dataset has a column `"prompt"` that contains the prompt for each example
    # (get_prompt_ids may return a NumPy array, so convert it to a list so that `+` concatenates rather than adds element-wise)
    prompt_ids = list(tokenizer.get_prompt_ids(batch["prompt"]))

    # encode target text to token ids
    token_ids = tokenizer(batch["sentence"]).input_ids

    # combine them to get our labels
    batch["labels"] = prompt_ids + token_ids
    return batch

You can try this with a toy example:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

# get_prompt_ids may return a NumPy array, so convert it to a list
# so that `+` concatenates the ids rather than adding them element-wise
prompt_ids = list(processor.get_prompt_ids("Nokia"))
token_ids = processor.tokenizer(" No kea phones are great").input_ids
labels = prompt_ids + token_ids

# let's check how the labels are decoded
print(processor.decode(labels))

Print Output:

'<|startofprev|> Nokia<|startoftranscript|><|notimestamps|> No kea phones are great<|endoftext|>'

-> we see the prompt Nokia nestled between the prompt token id and the BOS token id, and the target text nestled between the BOS and EOS token ids, which is the expected behaviour.

Now the tricky bit about getting this working more generally is how we get the prompt column in our dataset - we can't assume that every dataset is going to have examples with a trio of (audio, target text, prompt), most ASR datasets only have (audio, target text).

Maybe we could start with the LibriSpeech ASR dataset: since the dataset is taken from recorded samples of audio book narration, each sentence can be prompted with the previous one? i.e. if you have a dataset:

(audio_1, text_1)
(audio_2, text_2)
...
(audio_n, text_n)

You could augment it as:

(audio_2, text_2, prompt=text_1)
(audio_3, text_3, prompt=text_2)
...
(audio_n, text_n, prompt=text_n-1)

This works since we know the dataset samples are recorded sequentially. Here we just need to check that text_i follows on from text_i-1 by making sure it comes from the same speaker.

I think this would be a good starting point for adapting the fine-tuning script, but I don't think there's a way of generalising it to work with all datasets since we don't always have the prompts available?
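A minimal sketch of the pairing described above, in plain Python. It assumes the samples are already sorted in reading order and that each one carries `speaker_id` and `chapter_id` fields to decide whether two neighbours belong together; the function name and the choice to keep the first sample of each group with an empty prompt (rather than dropping it) are illustrative assumptions, not part of the fine-tuning script.

```python
from typing import Dict, List, Optional, Sequence


def add_previous_text_as_prompt(
    samples: List[Dict[str, str]],
    group_keys: Sequence[str] = ("speaker_id", "chapter_id"),  # assumed column names
) -> List[Dict[str, Optional[str]]]:
    """Attach the previous sample's text as a prompt when both samples share a group."""
    augmented = []
    previous = None
    for sample in samples:
        # Only treat the previous text as context if it comes from the same recording
        same_group = previous is not None and all(
            previous[key] == sample[key] for key in group_keys
        )
        prompt = previous["text"] if same_group else None
        augmented.append({**sample, "prompt": prompt})
        previous = sample
    return augmented


samples = [
    {"speaker_id": "s1", "chapter_id": "c1", "text": "It was a dark night."},
    {"speaker_id": "s1", "chapter_id": "c1", "text": "The wind howled."},
    {"speaker_id": "s2", "chapter_id": "c7", "text": "Chapter one."},
]
augmented = add_previous_text_as_prompt(samples)
for row in augmented:
    print(repr(row["prompt"]), "->", row["text"])
```

Samples whose prompt is `None` can either be filtered out or trained on without a prompt, which would also give a mix of prompted and unprompted examples.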

Lauler commented 1 year ago

For datasets where audio snippets are sequential (audiobooks) that makes sense!

A complementary solution that could be more general in nature is to perhaps wait for the PR that adds support for encoding timestamp tokens as is (https://github.com/huggingface/transformers/pull/24476).

A general preprocessing step involving encoding a separate "timestamp_encoded"-column could then perhaps work for both datasets with sequential audio snippets (LibriSpeech audiobooks), and those who already have longer audio samples with more granular timestamp information.

Then in the case of LibriSpeech (and any dataset with sequential audio snippets) the following preprocessing guide would apply:

  1. Extract the length of each audio sample and create a duration column.
  2. Encode the decoder input as "<|startofprev|>" + text_n-1 + "<|startoftranscript|><|timestamps|> <|0.00|>" + text_n + "<|duration_n|><|endoftext|>" as a separate column.

If it would be possible to train with timestamps, then a conceptually similar approach would apply to those who already have datasets with more granular timestamps. Their preprocessing would consist of creating a similar suitable column where the prompt and timestamps are already encoded properly.
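The two steps above could be sketched roughly as follows. This is only a string-level illustration: the token layout is copied from the guide as written (including `<|timestamps|>`) and has not been checked against the tokenizer's vocabulary, the 0.02 s timestamp step and the 30 s cap are assumptions, and both function names are hypothetical.

```python
def format_timestamp(seconds: float, precision: float = 0.02) -> str:
    """Round to the nearest timestamp step and render it as a <|x.xx|> string."""
    steps = round(seconds / precision)
    return f"<|{steps * precision:.2f}|>"


def build_decoder_text(previous_text: str, text: str, num_samples: int, sampling_rate: int) -> str:
    """Step 1: derive the duration. Step 2: assemble the decoder string."""
    # Duration from the raw audio length, capped at an assumed 30 s window
    duration_s = min(num_samples / sampling_rate, 30.0)
    # Leave the previous-context part out entirely when there is no previous text
    prefix = f"<|startofprev|>{previous_text}" if previous_text else ""
    return (
        f"{prefix}<|startoftranscript|><|timestamps|> <|0.00|>"
        f"{text}{format_timestamp(duration_s)}<|endoftext|>"
    )


example = build_decoder_text(" It was a dark night.", " The wind howled.", 64000, 16000)
print(example)
```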

I'm aware that existing audio datasets on the Hub are currently mostly composed of single sentences. However, I think this is increasingly going to change with time. The question for these new datasets then becomes:

Right now it is not obvious how such information should best be included in a HF dataset. As an example, our organization published an audio dataset of parliamentary recordings. In its original form, we have sentence-aligned these transcripts. However, in the published version on the Hub, we concatenated sequential sentences and their corresponding audio snippets until they filled as much as possible of a 30s bucket.

We have been discussing the most flexible way of adding the more granular sentence-level timestamps to the dataset, with our top two choices being:

I think the first option is probably the best, since it's model agnostic, and it will allow us and any user to remix and re-encode the dataset in whatever way they need.

A separate question: Would the prompt ids automatically be masked in the loss calculation in the current Whisper implementation?

AvivSham commented 1 year ago

Hi @sanchit-gandhi, Thank you for your reply! Following your reply:

https://github.com/huggingface/transformers/issues/24272#issuecomment-1633834410

Do we manually need to mask the prompt_ids since we do not want to include those when calculating the loss? Is it dealt with internally (by looking inside the code I did not find such masking)? What is the right approach here?

Thanks in advance.

samuelazran commented 1 year ago

Hi Aviv, in the paper it says: During training it should “mask out the training loss over the previous context text and train the model to predict all other tokens”.

I'm not sure how to implement it with HF Trainer. But it is an important feature (I posted on it in the forum half a year ago) and I hope you can test some ideas and see what works.

AvivSham commented 1 year ago

@samuelazran Thank you for your comment. I'm looking for an official guide here since it is a bit tricky to integrate the implementation with HF. @sanchit-gandhi @Lauler Maybe you can help us with it? how should we approach this? Do we manually need to mask the prompt_ids since we do not want to include those when calculating the loss? Is it dealt with internally (by looking inside the code I did not find such masking)? What is the right approach here?

sanchit-gandhi commented 1 year ago

If you're taking multiple audio samples < 30s and combining them to give your prompt and target text, you probably don't need the timing information within each sample. You can get the length of each sample by measuring the length of the audio array, and dividing it by the sampling rate (assuming there's little to no silence):

audio_length_s = len(audio["array"]) / audio["sampling_rate"]

Timing information would be useful if you had the opposite situation, where you had very long samples that you wanted to split up into smaller ones. Here, you would split on appropriately chosen timestamps.
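As a rough sketch of the splitting case, assuming the long sample comes with sentence-level `(start_s, end_s, text)` segments: consecutive segments are grouped greedily so that each group spans at most 30 seconds. The segment layout, the function name and the greedy strategy are illustrative assumptions.

```python
from typing import List, Tuple

Segment = Tuple[float, float, str]  # (start_s, end_s, text), an assumed layout


def pack_segments(segments: List[Segment], max_len_s: float = 30.0) -> List[List[Segment]]:
    """Greedily group consecutive segments into chunks spanning at most max_len_s."""
    chunks: List[List[Segment]] = []
    current: List[Segment] = []
    for segment in segments:
        _, end, _ = segment
        # Start a new chunk if adding this segment would exceed the window
        if current and end - current[0][0] > max_len_s:
            chunks.append(current)
            current = []
        current.append(segment)
    if current:
        chunks.append(current)
    return chunks


segments = [(0.0, 12.0, "a"), (12.5, 25.0, "b"), (25.5, 41.0, "c"), (41.0, 50.0, "d")]
chunks = pack_segments(segments)
chunk_texts = [" ".join(text for _, _, text in chunk) for chunk in chunks]
print(chunk_texts)
```

The audio for each chunk would then be sliced between the start of its first segment and the end of its last one. A single segment longer than the window is left as its own chunk here and would need separate handling.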

  • How does a user best add timestamp information to their dataset that has longer audio snippets with granular timestamps?

I'm not sure I fully understand this question - you want to take long audio samples and add timestamp information to them? Or you want to format audio samples that have existing timestamp information?

Also agree that the first option you've proposed is best! I don't think we can make a very general recommendation here as to the format your data should be in, since this is quite a niche application of fine-tuning and one that is conditional on the form of your original data. But the design you've proposed sounds sensible for your use case!

Would the prompt ids automatically be masked in the loss calculation in the current Whisper implementation?

No they wouldn't - we would need to update this. I know that @peregilk and co from NbAiLab have done something similar in Flax: https://github.com/NbAiLab/nb-whisper/blob/352bf2d0efb073405c90fb4ef048a5d52b6128b6/run_nb_flax_speech_recognition_seq2seq_streaming_dev.py#L579-L582

We would need to do the same for the PyTorch code. Would also be interested in hearing from @peregilk how you constructed your prompts + text pairs! Are we on the right lines by pairing consecutive samples of our dataset together?

peregilk commented 12 months ago

Sure @sanchit-gandhi. Our dataset consists of multiple different sources, subtitles, parliament transcripts, audio books and created datasets. In some cases we do have the text directly preceding the current sample. In our dataset, we simply add this as a separate "pretext"-column. In a lot of scenarios this information is not available, and here we simply leave that field empty.

We have not added timestamps to the pretext (yet). I see your point (with reference to the article) @Lauler regarding not predicting the end timestamp. We have not tried that, but one of our datasets is cut on pauses, and here we have a lot of incomplete sentences (i.e. not ending with punctuation and not starting with a capital letter). This seems to be well handled by the model.

We ended up with a dataset format with multiple columns for each audio clip. One sample could, for instance, have text, timestamp_text, pretext, english_translation, nynorsk_transcription, etc. For other samples, very few of these are filled out. This means that for one audio clip we can generate 1-5 training samples. We have modified the data loader to handle this so that we can generate the actual prompt on the fly. I can share this with you @Lauler if you are interested.
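To illustrate the idea of generating several training samples from one multi-column record, here is a small sketch. It is not the NbAiLab data loader: the column names are taken from the description above, while the task labels, the output dictionary layout and the function name are hypothetical.

```python
from typing import Dict, Iterator, Optional

# Hypothetical mapping from a target column to a task label
TARGET_COLUMNS = {
    "text": "transcribe",
    "timestamp_text": "transcribe_with_timestamps",
    "english_translation": "translate",
    "nynorsk_transcription": "transcribe_nynorsk",
}


def expand_record(record: Dict[str, Optional[str]]) -> Iterator[Dict[str, Optional[str]]]:
    """Yield one training sample per filled-in target column."""
    # An empty or missing pretext means the sample is trained without a prompt
    pretext = record.get("pretext") or None
    for column, task in TARGET_COLUMNS.items():
        target = record.get(column)
        if not target:
            continue  # column not available for this clip
        yield {"audio": record["audio"], "task": task, "target": target, "prompt": pretext}


record = {
    "audio": "clip_001.wav",
    "text": "god morgen",
    "english_translation": "good morning",
    "pretext": "",
}
expanded = list(expand_record(record))
for sample in expanded:
    print(sample["task"], "->", sample["target"])
```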

@AvivSham Personally I found the masking a bit tricky. This snippet helped me a lot in understanding what was going on. Maybe you can reuse it: https://github.com/NbAiLab/nb-whisper/blob/352bf2d0efb073405c90fb4ef048a5d52b6128b6/run_nb_flax_speech_recognition_seq2seq_streaming_dev.py#L1692-L1697

samuelazran commented 11 months ago

Hi @AvivSham, were you able to make some progress on training with prompts? If you and others are interested, let's combine forces and work on it until we find someone from Huggingface who can help.

Does anyone know who is a relevant person from HF to give us ideas or directions? maybe @patrickvonplaten?

AvivSham commented 11 months ago

Hi @samuelazran, there is no significant progress from my end. Maybe someone from HF may help.

sanchit-gandhi commented 11 months ago

Thanks for the comprehensive summary @peregilk! Cool to see that you're still super up for implementing this @samuelazran. I personally won't have time to generalise the fine-tuning script to use the prompted tokens in the training objective, but I'm more than happy to answer any questions / queries if you'd like to have a go yourself.

IMO the most challenging bit of this integration will be constructing the (prompt, text) pairs -> I can't see a way of making this generalise across all datasets? Given a data sample (text_i, audio_i), how can we know the corresponding prompt for the target text? Most ASR datasets are constructed with independent (text, audio) samples, so it's not trivial to find the text prompt for each sample, if it even exists.

If you'd like to pursue this, I'd recommend starting with the LibriSpeech dataset (for which I left some details here: https://github.com/huggingface/transformers/issues/24272#issuecomment-1633834410)

anderleich commented 11 months ago

Hi,

I found this thread really interesting. Last month I suggested what could be a simple starting point to prepare the dataset with prompts in the Huggingface Forum. See https://discuss.huggingface.co/t/finetuning-whisper-with-prompts/43053/3?u=andercorral

@sanchit-gandhi I think this could make your solution consistent with the API. I've also added prompts with some probability as stated in the original paper:

import random

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode prompts and target text to token ids - we assume that the dataset has a column `"prompt"` that contains the prompt for each example
    # with probability 0.5 the prompt is passed along with the target text
    # (this assumes a tokenizer whose text-pair handling places the prompt after <|startofprev|>, as in the forum post linked above;
    # the stock tokenizer may simply concatenate the pair)
    if random.uniform(0, 1) > 0.5:
        token_ids = tokenizer(batch["sentence"], batch["prompt"]).input_ids
    else:
        token_ids = tokenizer(batch["sentence"]).input_ids

    batch["labels"] = token_ids
    return batch

samuelazran commented 11 months ago

@sanchit-gandhi

Thank you for your reply. However, I don't think that providing the dataset in a certain format is the biggest challenge. I'd be glad to provide the data in any way that works. The most important thing is to be able to have a batch with multiple items, some of which contain prompts and some of which do not; the question is how to handle this during training.

sanchit-gandhi commented 11 months ago

Hey @samuelazran - for creating batches with prompted and non-prompted data, you can do the random switching in the prepare_dataset function as shown very nicely by @anderleich

samuelazran commented 10 months ago

But during training, how do you ensure calculating loss only on the tokens that come after the prompt? We need to make sure the model will not generate the prompt part and start generating from the transcript labels or at least ignore the loss over the prompt tokens as in the original paper:

"We only mask out the training loss over the previous context text, and train the model to predict all other tokens." https://arxiv.org/pdf/2212.04356.pdf

This is the main challenge. I'm looking for insights about implementing it with Huggingface Transformers and especially using Transformers Trainer.

sanchit-gandhi commented 10 months ago

Hey @samuelazran - the masking part is pretty easy. Here, we just set the labels to -100 for the prompt tokens (-100 is the label value that the loss function ignores), so that they are excluded from the loss. Here's how you would do this in numpy: https://github.com/NbAiLab/nb-whisper/blob/352bf2d0efb073405c90fb4ef048a5d52b6128b6/run_nb_flax_speech_recognition_seq2seq_streaming_dev.py#L579-L582

You could port this to torch and it would be quite straightforward this way! I still maintain that getting the dataset in the correct format is the toughest part to generalise.
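As a library-free illustration of the masking idea (not the linked Flax code), the sketch below replaces every label before `<|startoftranscript|>` with -100. The token ids are toy placeholders rather than real Whisper ids, and the function name is hypothetical; in a torch port the same logic would operate on a label tensor.

```python
from typing import List

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

# Toy placeholder ids, NOT the real Whisper vocabulary ids
START_OF_PREV, START_OF_TRANSCRIPT, NO_TIMESTAMPS, END_OF_TEXT = 1, 2, 3, 0


def mask_prompt_labels(labels: List[int], start_of_transcript_id: int) -> List[int]:
    """Replace every label that precedes <|startoftranscript|> with IGNORE_INDEX."""
    if start_of_transcript_id not in labels:
        return list(labels)  # nothing recognisable to mask
    boundary = labels.index(start_of_transcript_id)
    return [IGNORE_INDEX] * boundary + list(labels[boundary:])


prompted = [START_OF_PREV, 11, 12, START_OF_TRANSCRIPT, NO_TIMESTAMPS, 21, 22, END_OF_TEXT]
unprompted = [START_OF_TRANSCRIPT, NO_TIMESTAMPS, 21, END_OF_TEXT]

masked_prompted = mask_prompt_labels(prompted, START_OF_TRANSCRIPT)
masked_unprompted = mask_prompt_labels(unprompted, START_OF_TRANSCRIPT)
print(masked_prompted)
print(masked_unprompted)
```

Whether the `<|startoftranscript|>` token itself should also be masked depends on how the training script shifts labels into decoder inputs, so that detail is left open here.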

AvivSham commented 10 months ago

I have a follow-up question related to finetuning Whisper in general. Whisper consumes lots of GPU memory (>30GB for the medium-sized model). What is the current way to use DDP or FSDP with Whisper? (found this related issue) When we tried these strategies (on 4 V100 GPUs, 16GB each) we saw that most of the memory is allocated on the first GPU instead of being balanced equally across all four. We aim to train the model using multiple V100 16GB cards rather than large-memory GPUs, which would be possible if the memory were spread equally across the cards. This is an extremely annoying problem due to the shortage of large-memory GPUs. Can you please help? @connor-henderson @sanchit-gandhi

sanchit-gandhi commented 10 months ago

torch.distributed.launch should still work (although it will soon be deprecated in favour of torchrun): https://github.com/huggingface/transformers/tree/main/examples/pytorch/speech-recognition#multi-gpu-whisper-training

AvivSham commented 10 months ago

Hi @sanchit-gandhi, Thank you for your quick response. Please let me elaborate on the issue I'm facing. As mentioned before, we are trying to fine-tune Whisper with prompts on a machine with 4 V100 GPUs. Following your suggestion, we replaced the Adam optimizer with Adafactor, which reduced the memory footprint a lot (thanks for the tip!). However, we are now facing another issue: when training Whisper with prompts, during evaluation (which uses the generate method) the prompts are passed as a list of independent prompts, which is not supported (as far as I know).
Note that this behavior is specifically related to DDP training and occurs during evaluation; if we train using a single GPU / CPU, this error is not raised.

The config we use:

training_args = Seq2SeqTrainingArguments(
    output_dir="./outputs/foo_foo",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=5e-6,
    warmup_steps=0,
    # max_steps=4000,
    max_steps=10000,
    gradient_checkpointing=False,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=1,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=10000,
    eval_steps=10000,
    logging_steps=1,
    report_to=["none"],
    load_best_model_at_end=True,
    metric_for_best_model="eval_validation_wer",
    greater_is_better=False,
    push_to_hub=False,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    optim="adafactor"
)

The raised error:

decoder_input_ids_start = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id
TypeError: only integer tensors of a single element can be converted to an index

The reason for this error is this line, since prompt_ids is a list of prompts with length equal to per_device_eval_batch_size * NUM_DEVICES (which equals 4 in our case). After unpacking, decoder_start_token_id is a list instead of an int.

Can you please help resolve this issue? Thanks!

sanchit-gandhi commented 10 months ago

Thanks for explaining the issue! If I understand correctly, the problem essentially boils down to prompted generation not working with batching? Or does generation work with batching, just not with your DDP set-up? I'm not entirely sure what you mean by:

the prompts are passed as a list of independent prompts

The prompts should be a torch tensor of shape bsz, seq_len?

Perhaps if you have a reproducible code-snippet to isolate the behaviour you are alluding to for evaluation (i.e. explicitly calling model.generate with your prompt ids) I can take a deeper dive! Feel free also to open a new issue to track this problem, since we're diverging a bit from the original issue thread and I want to make sure both are tracked appropriately. Thanks @AvivSham!

AvivSham commented 9 months ago

Hi @sanchit-gandhi, I just opened a new issue to keep this thread clean and would appreciate your response.

samuelazran commented 9 months ago

Thank you for the reference. This is great.

I'm wondering about the padding issue. Since the sequences have different lengths, we right-pad them to the longest one. How do you handle prompts with different lengths in the same training batch? Should they be padded? Should we pad on the left? Was it done that way during pretraining? Any reference or source would help a lot.

sanchit-gandhi commented 9 months ago

We still right pad everything after the EOS token. You just need to mask the prompt between the <|startofprev|> and <|startoftranscript|> in the attention mask.

So you'll get batches that look like:

<|startofprev|> Shorter prompt.<|startoftranscript|> Hey what's up?<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
<|startofprev|> Longer prompt here. Notice that we still right pad.<|startoftranscript|> Hey what's up?<|endoftext|>

And attention masks that look like:

1 0 0 1 1 1 1 1 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 1 1 1 1 1

=> the 1001 at the start of the first item is where we have masked the prompt (the two zeros) in-between the <|startofprev|> and <|startoftranscript|> tokens. The zeros at the end are where we mask the extra padding <|endoftext|> token

AvivNavon commented 8 months ago

@sanchit-gandhi Could you please explain why we need to mask the prompt in the attention mask, and not just mask the loss over the prompt token (by setting the label to -100 on prompt tokens)?

sanchit-gandhi commented 8 months ago

We don't want to train the model to generate prompts during inference time, only the transcription of the audio. If we train on the prompt tokens, the model will learn to generate prompts, which is not the intended behaviour. Thus, we only train on the text tokens, and mask the prompt. This is also how OpenAI train the model. See page 3 of the Whisper paper for details:

[image: Screenshot 2023-11-27 at 10 50 26]

AvivNavon commented 8 months ago

Thanks @sanchit-gandhi, I understand that the labels should be set to -100 and that the loss will then be masked on the prompt tokens. But why do we need to modify the attention_mask like you suggested?

sanchit-gandhi commented 7 months ago

Super! Glad the labels logic makes sense. Sorry you're entirely right about the attention mask! We don't change the attention mask for the prompt (I got carried away with my example 😅). You can see the PyTorch code for this here: https://github.com/huggingface/distil-whisper/blob/914dcdf3919552d5a3826a9d5db99b059ddcc16e/training/run_distillation.py#L343

Feel free to copy this updated data collator for your fine-tuning script!
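For readers who want the gist without opening the link, here is a simplified, library-free sketch of the batching behaviour this thread converged on. It is not the distil-whisper collator: labels are right-padded, both the prompt and the padding are set to -100, and the attention mask only zeroes out the padding while leaving the prompt attended to. The token ids are toy placeholders and the function name is hypothetical.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value excluded from the loss
START_OF_PREV, START_OF_TRANSCRIPT, NO_TIMESTAMPS, END_OF_TEXT = 1, 2, 3, 0  # toy ids


def collate_labels(batch: List[List[int]], start_of_transcript_id: int) -> Dict[str, List[List[int]]]:
    """Right-pad label sequences and mask prompt and padding positions in the labels."""
    max_len = max(len(seq) for seq in batch)
    labels, attention_mask = [], []
    for seq in batch:
        pad = max_len - len(seq)
        # Everything before <|startoftranscript|> is prompt and is excluded from the loss
        boundary = seq.index(start_of_transcript_id) if start_of_transcript_id in seq else 0
        # Padding goes in as IGNORE_INDEX directly, so the real end-of-text label is kept
        labels.append([IGNORE_INDEX] * boundary + seq[boundary:] + [IGNORE_INDEX] * pad)
        # The prompt stays visible to attention; only the padding is masked
        attention_mask.append([1] * len(seq) + [0] * pad)
    return {"labels": labels, "attention_mask": attention_mask}


batch = [
    [START_OF_PREV, 11, START_OF_TRANSCRIPT, NO_TIMESTAMPS, 21, END_OF_TEXT],
    [START_OF_PREV, 11, 12, 13, START_OF_TRANSCRIPT, NO_TIMESTAMPS, 21, 22, END_OF_TEXT],
]
collated = collate_labels(batch, START_OF_TRANSCRIPT)
for row_labels, row_mask in zip(collated["labels"], collated["attention_mask"]):
    print(row_labels, row_mask)
```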