huggingface / blog

Public repo for HF blog posts
https://hf.co/blog

Finetuning Whisper: validation loss keeps decreasing while validation WER increases #933

Open anderleich opened 1 year ago

anderleich commented 1 year ago

Hi, I've followed this blog post https://huggingface.co/blog/fine-tune-whisper to fine-tune Whisper on my own dataset.

Everything seems to be working as expected. However, I've noticed some strange behaviour: while the validation loss keeps decreasing at every checkpoint, the validation WER increases. As a consequence, the best model in terms of WER is obtained at an early stage of fine-tuning. At inference time, however, the last saved checkpoint outperforms the best-WER checkpoint.

Do you have any idea what is happening?

anderleich commented 1 year ago

[Three images attached]

osanseviero commented 1 year ago

cc @Vaibhavs10 @sanchit-gandhi

sanchit-gandhi commented 1 year ago

Hey @anderleich! Sorry for the late reply here. Could you share your training arguments so we can get a feel for the kind of set-up you're employing? And also your compute_metrics function? The fact that the eval loss is going down suggests that we're predicting the correct tokens, but not decoding them to words properly with the tokenizer.

Even better would be a Colab link that we could look through and reproduce ourselves :)

ammaraldirawi commented 1 year ago

I had the same issue; in my case the cause was setting bf16=True instead of fp16=True in Seq2SeqTrainingArguments.

anderleich commented 1 year ago

Hi @sanchit-gandhi,

Thanks for your response!

Indeed, judging from the loss alone, the model seems to be learning. In fact, it improves on the `openai/whisper-small` baseline when evaluated on several custom test sets, so the training process itself is working properly.

This is my compute_metrics function:

import evaluate

metric = evaluate.load("wer")

# `processors` is assumed to be the WhisperProcessor used for training
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 (positions ignored by the loss) with the pad_token_id so they can be decoded
    label_ids[label_ids == -100] = processors.tokenizer.pad_token_id

    # decode token ids to text, dropping special tokens (including padding)
    pred_str = processors.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processors.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

And these are my training arguments:

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_MODEL_DIR,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation,
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=20000,
    gradient_checkpointing=False,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    #generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    #eval_delay=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    dataloader_num_workers=20
)

sanchit-gandhi commented 1 year ago

Thanks for sharing those additional details @anderleich! I can't see anything that looks wrong based on these, and given what you've said about your tests it certainly sounds like the model is learning correctly.

What I would suggest doing is removing these two args from the Seq2SeqTrainingArguments:

-   metric_for_best_model="wer",
-   greater_is_better=False,

Trainer will then select the checkpoint with the lowest eval loss as your 'best' model at the end of training (instead of the one with the lowest eval WER).

Based on your logs, this should select the checkpoint that actually performs best.
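The selection rule can be illustrated with a simplified, pure-Python sketch. `select_best_checkpoint` is a hypothetical helper, not Trainer's actual implementation, and the eval logs are made-up numbers that only mimic the pattern in this thread (loss falling, WER rising). The `eval_` prefixes mirror how Trainer names evaluation metrics.

```python
from typing import Dict, List, Optional


def select_best_checkpoint(
    eval_logs: List[Dict[str, float]],
    metric_key: str = "eval_loss",
    greater_is_better: bool = False,
) -> Optional[int]:
    """Return the step whose metric_key value is best (hypothetical helper)."""
    best_step: Optional[int] = None
    best_value: Optional[float] = None
    for log in eval_logs:
        value = log[metric_key]
        # strict comparison: on a tie, the earlier checkpoint is kept
        if best_value is None:
            improved = True
        elif greater_is_better:
            improved = value > best_value
        else:
            improved = value < best_value
        if improved:
            best_value = value
            best_step = int(log["step"])
    return best_step


# Made-up eval logs: loss keeps falling while WER keeps rising.
eval_logs = [
    {"step": 1000, "eval_loss": 0.50, "eval_wer": 20.0},
    {"step": 2000, "eval_loss": 0.40, "eval_wer": 22.0},
    {"step": 3000, "eval_loss": 0.35, "eval_wer": 25.0},
]

# metric_for_best_model="wer", greater_is_better=False
print(select_best_checkpoint(eval_logs, "eval_wer", greater_is_better=False))  # 1000
# both arguments removed: default selection on eval loss, lower is better
print(select_best_checkpoint(eval_logs))  # 3000
```

Selecting on WER picks the first checkpoint, while the default loss-based selection picks the last one, which is consistent with the last checkpoint doing better at inference time.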

iamgroot42 commented 1 year ago

@anderleich I recently tried fine-tuning on LibriSpeech and realized that, for WER computation (in the compute_metrics function), you should normalize the text, as is done in the evaluation example on the model card:

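# _normalize works on a single string, so apply it to each decoded transcription
pred_str = [processors.tokenizer._normalize(s) for s in pred_str]
label_str = [processors.tokenizer._normalize(s) for s in label_str]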

Hope it helps!
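The effect of normalization on WER can be seen with a small, self-contained sketch. `simple_normalize` is a hypothetical and much cruder stand-in for Whisper's English normalizer (it only lower-cases, strips ASCII punctuation and collapses whitespace), `corpus_wer` is a plain word-level Levenshtein WER rather than the `evaluate` implementation, and the sentences are invented. Casing and punctuation differences alone push the un-normalized WER to 80%, even though every word is recognised.

```python
import string
from typing import List


def simple_normalize(text: str) -> str:
    """Hypothetical normalizer: lower-case, drop ASCII punctuation, collapse spaces."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    return " ".join(text.split())


def word_edit_distance(ref: List[str], hyp: List[str]) -> int:
    """Levenshtein distance over words (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))  # distances from an empty reference prefix
    for i in range(1, len(ref) + 1):
        curr = [i] + [0] * len(hyp)
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,         # deletion
                curr[j - 1] + 1,     # insertion
                prev[j - 1] + cost,  # substitution or match
            )
        prev = curr
    return prev[-1]


def corpus_wer(references: List[str], predictions: List[str]) -> float:
    """Total word edits divided by total reference words, as a percentage."""
    edits = sum(word_edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))
    words = sum(len(r.split()) for r in references)
    return 100 * edits / words


references = ["Hello, world.", "It's a test."]
predictions = ["hello world", "its a test"]

print(corpus_wer(references, predictions))  # 80.0
print(corpus_wer([simple_normalize(s) for s in references],
                 [simple_normalize(s) for s in predictions]))  # 0.0
```

Whisper's own English normalizer does considerably more (for example, it also standardizes numbers and spellings), so this only illustrates why an un-normalized WER can look much worse than the transcriptions actually are.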

sanchit-gandhi commented 7 months ago

There's another example here for normalising that you can use: https://github.com/huggingface/community-events/blob/a2d9115007c7e44b4389e005ea5c6163ae5b0470/whisper-fine-tuning-event/run_speech_recognition_seq2seq_streaming.py#L514-L532

There are also some other valuable resources for fine-tuning here: https://github.com/huggingface/community-events/tree/main/whisper-fine-tuning-event#tips-and-tricks