huggingface / distil-whisper

Distilled variant of Whisper for speech recognition. 6x faster, 50% smaller, within 1% word error rate.
MIT License

[Training]: Configure pseudo-labels or text labels as targets in KD #56

Closed: sanchit-gandhi closed this pull request 6 months ago

sanchit-gandhi commented 6 months ago

Updates the distillation script so that the user can choose between training on pseudo-labels or on text labels as targets. This lets the user run "knowledge distillation" directly on the text labels provided by an ASR dataset, such as Common Voice, so the pseudo-labelling step can be skipped entirely when training this way.

To train directly on the text labels provided in the dataset, set `--use_pseudo_labels false` and set `--text_column_name` (and `--eval_text_column_name` for the evaluation set) to the column that holds your text targets. For example, to train distil-large-v3 on the German ("de") subset of Common Voice 15 without pseudo-labels:

```bash
#!/usr/bin/env bash

accelerate launch --mixed_precision=bf16 run_distillation.py \
  --model_name_or_path "./distil-large-v3-init" \
  --teacher_model_name_or_path "openai/whisper-large-v3" \
  --train_dataset_name "mozilla-foundation/common_voice_15_0" \
  --train_dataset_config_name "de" \
  --train_split_name "train" \
  --text_column_name "sentence" \
  --eval_dataset_name "mozilla-foundation/common_voice_15_0" \
  --eval_dataset_config_name "de" \
  --eval_split_name "validation" \
  --eval_text_column_name "sentence" \
  --eval_steps 5000 \
  --save_steps 5000 \
  --warmup_steps 500 \
  --learning_rate 1e-4 \
  --lr_scheduler_type "linear" \
  --logging_steps 25 \
  --save_total_limit 1 \
  --max_steps 50000 \
  --per_device_train_batch_size 64 \
  --per_device_eval_batch_size 64 \
  --dataloader_num_workers 16 \
  --preprocessing_num_workers 16 \
  --ddp_timeout 7200 \
  --dtype "bfloat16" \
  --output_dir "./" \
  --use_pseudo_labels "false" \
  --condition_on_prev_probability "0.0" \
  --do_train \
  --do_eval \
  --gradient_checkpointing \
  --overwrite_output_dir \
  --predict_with_generate \
  --freeze_encoder \
  --streaming
```

In other words, the targets are the same text labels used for standard fine-tuning, but the teacher still influences training through the KL-divergence loss.
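A rough sketch of the resulting objective may help. This assumes the standard temperature-scaled knowledge-distillation formulation, with loss weights $\alpha_{\mathrm{CE}}$, $\alpha_{\mathrm{KL}}$ and temperature $\tau$ treated as generic symbols (the exact values used in `run_distillation.py` are not stated in this PR). Both models are teacher-forced on the same target sequence $y = (y_1, \dots, y_N)$ given the audio features $\mathbf{X}$, and `--use_pseudo_labels` selects what $y$ is:

$$
\mathcal{L} = \alpha_{\mathrm{CE}}\,\mathcal{L}_{\mathrm{CE}} + \alpha_{\mathrm{KL}}\,\mathcal{L}_{\mathrm{KL}}
$$

$$
\mathcal{L}_{\mathrm{CE}} = -\frac{1}{N}\sum_{i=1}^{N} \log p_{\mathrm{stu}}\left(y_i \mid y_{<i}, \mathbf{X}\right)
$$

$$
\mathcal{L}_{\mathrm{KL}} = \frac{\tau^2}{N}\sum_{i=1}^{N} \mathrm{KL}\left(p_{\mathrm{tea}}^{(\tau)}\left(\cdot \mid y_{<i}, \mathbf{X}\right) \,\middle\|\, p_{\mathrm{stu}}^{(\tau)}\left(\cdot \mid y_{<i}, \mathbf{X}\right)\right)
$$

Here $p^{(\tau)}$ denotes a softmax over the vocabulary with logits divided by $\tau$. With `--use_pseudo_labels false`, $y$ is the dataset transcript (the `sentence` column for Common Voice); with pseudo-labels, $y$ is the teacher's generated transcript. Setting $\alpha_{\mathrm{KL}} = 0$ would reduce the objective to plain fine-tuning on the text labels, so the KL term is what carries the teacher's influence.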