huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

Fix total_train_steps issue related to gradient_accumulation_steps #51

Closed · bofenghuang closed this 6 months ago

bofenghuang commented 6 months ago

Hi @sanchit-gandhi,

Thank you for the fantastic work and for sharing the code!

I made a tiny modification here: dividing steps_per_epoch by gradient_accumulation_steps to get the correct number of effective steps (with data parallelism, distributed training, and gradient accumulation).

sanchit-gandhi commented 6 months ago

Thanks for your kind words @bofenghuang! I believe this is OK as is, since we only count a training step when the accelerator performs a gradient update, i.e. once every gradient_accumulation_steps batches: https://github.com/huggingface/distil-whisper/blob/e2139a79993b4a96680c383b4a0ba60ec586104d/training/run_distillation.py#L1498-L1501
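
For context, this is the usual Accelerate counting idiom: the step counter only increments when accelerator.sync_gradients is True, which happens once per optimizer update. A toy sketch of that pattern (not a verbatim excerpt of run_distillation.py; the linear model and the cur_step name are just for illustration):

```python
import torch
from accelerate import Accelerator

# Toy setup to illustrate the counting pattern; the real script trains a
# distilled Whisper student, not a linear layer.
gradient_accumulation_steps = 4
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataloader = torch.utils.data.DataLoader(torch.randn(64, 8), batch_size=2)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

cur_step = 0  # counts optimizer updates, not micro-batches
for batch in dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).mean()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    # sync_gradients is True only on the micro-batch where the optimizer
    # actually updates, i.e. once every gradient_accumulation_steps batches.
    if accelerator.sync_gradients:
        cur_step += 1

# Single process: 64 examples / batch size 2 / accumulation 4 = 8 counted steps.
print(cur_step)
```

Run on a single process, this counts 8 steps for 32 micro-batches, i.e. one per optimizer update rather than one per forward/backward pass.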

bofenghuang commented 6 months ago

Thanks for your response! In my understanding, since we count a training step only when the accelerator performs a gradient update, we also need to account for gradient accumulation when computing the total number of training steps (via steps_per_epoch here).

Therefore, steps_per_epoch should be n_training_examples / (train_batch_size_per_device * n_gpus * gradient_accumulation_steps), as sketched below.
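
A minimal sketch of that calculation with made-up numbers (the variable names are illustrative, not the exact ones in run_distillation.py):

```python
import math

# Illustrative values, not taken from any real training run.
n_training_examples = 10_000
train_batch_size_per_device = 16
n_gpus = 8
gradient_accumulation_steps = 4
num_epochs = 3

# Micro-batches (forward/backward passes) per epoch across all devices.
micro_batches_per_epoch = math.ceil(
    n_training_examples / (train_batch_size_per_device * n_gpus)
)  # 79

# Optimizer updates per epoch: the quantity cur_step actually counts, since a
# step is only logged when the accelerator performs a gradient update.
steps_per_epoch = math.ceil(micro_batches_per_epoch / gradient_accumulation_steps)  # 20

total_train_steps = steps_per_epoch * num_epochs  # 60
print(steps_per_epoch, total_train_steps)
```

Without the final division, total_train_steps would be gradient_accumulation_steps times larger than the number of updates the loop will ever count, so the learning-rate schedule and the stopping condition would be based on steps that never occur.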