
[Interspeech 2024] Whisper-Flamingo: Integrating Visual Features into Whisper for Audio-Visual Speech Recognition and Translation
https://arxiv.org/abs/2406.10082

Whisper-Flamingo

Updates

Nov 19, 2024: We achieved SOTA ASR (1.3% WER) and SOTA AVSR (1.4% WER) on LRS2 - checkpoints are released below.
Oct 11, 2024: We achieved SOTA ASR (0.68% WER) and SOTA AVSR (0.72% WER) on LRS3 by training on LRS3 and VoxCeleb2 - checkpoints are released below.

Introduction

Integrating Visual Features into Whisper for Audio-Visual Speech Recognition and Translation

We propose Whisper-Flamingo, which integrates visual features into the Whisper speech recognition and translation model with gated cross attention. Our audio-visual Whisper-Flamingo outperforms audio-only Whisper on English speech recognition and En-X translation for 6 languages in noisy conditions. Moreover, Whisper-Flamingo is versatile: it performs all of these tasks with a single set of parameters, while prior methods are trained separately on each language.

Video Demos

Check out the video demo below (turn sound on). We made several videos about Whisper-Flamingo:

Colab Demos

We support two colab demos (local copies in ./notebooks):

Virtual Environment for Training and Testing

Since this project uses the MuAViC dataset, we base our virtual environment on theirs.

Create a fresh virtual environment:

conda create -n whisper-flamingo python=3.8 -y
conda activate whisper-flamingo

Clone MuAViC repo and install their requirements:

conda install -c conda-forge ffmpeg==4.2.2 -y
conda install -c conda-forge sox -y
git clone https://github.com/facebookresearch/muavic.git muavic-setup
cd muavic-setup
pip install -r requirements.txt
cd ..

Clone the "muavic" branch of av_hubert's repo and install Fairseq:

git clone -b muavic https://github.com/facebookresearch/av_hubert.git
cd av_hubert
git submodule init
git submodule update
# Install av-hubert's requirements
pip install -r requirements.txt
# Install fairseq
cd fairseq
pip install --editable ./
cd ../..

Install extra packages used in our project:

pip install tiktoken==0.5.2 pytorch-lightning==2.1.3 numba==0.58.1 transformers==4.36.2 evaluate tensorboardX
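
As a quick sanity check (not part of the repo), you can confirm that the core dependencies import cleanly before moving on:

# Optional sanity check: verify that the main packages resolve in the new environment.
import torch
import fairseq
import pytorch_lightning as pl
import transformers
import tiktoken  # just confirming it imports

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("fairseq:", fairseq.__version__)
print("pytorch-lightning:", pl.__version__)
print("transformers:", transformers.__version__)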

Download and prepare data

LRS3 / MuAViC: We provide all data to reproduce the results on the test set. For instructions on how to prepare the LRS3 training set (and more details about the test noise), see preparation/README.md.

Download and extract our resources:

wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/muavic.tar.gz
wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/noise.tar.gz
tar -xf muavic.tar.gz
tar -xf noise.tar.gz
echo $(pwd)/noise/babble/muavic/babble_all.wav > ./noise/babble/muavic/test.tsv
echo $(pwd)/noise/babble/muavic/babble_all.wav > ./noise/babble/muavic/valid.tsv
echo $(pwd)/noise/babble/lrs3/noise.wav > ./noise/babble/lrs3/test.tsv
echo $(pwd)/noise/babble/lrs3/noise.wav > ./noise/babble/lrs3/valid.tsv
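
To confirm the manifests were written correctly, a small check like the following (not part of the repo) verifies that each .tsv line points at an existing wav file:

# Optional check: each noise manifest should contain the absolute path of an existing wav file.
from pathlib import Path

manifests = [
    "noise/babble/muavic/test.tsv",
    "noise/babble/muavic/valid.tsv",
    "noise/babble/lrs3/test.tsv",
    "noise/babble/lrs3/valid.tsv",
]
for tsv in manifests:
    for line in Path(tsv).read_text().splitlines():
        wav = Path(line.strip())
        print(f"{tsv}: {wav} -> {'OK' if wav.is_file() else 'MISSING'}")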

LRS2: The data can be downloaded here after signing a license and sending it to the BBC (helper script: notebooks/lrs2_download.ipynb). In our experience, it took a week to receive the username & password for the data download. We used the AutoAVSR scripts to process LRS2 (using the provided facial landmarks). Finally, the AutoAVSR data lists must be converted to AV-HuBERT / Fairseq manifests. We provide a script to do this (notebooks/lrs2_make_tsv.ipynb).
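
The conversion itself is handled by notebooks/lrs2_make_tsv.ipynb. Purely for orientation, the sketch below shows the kind of manifest it needs to produce, assuming the usual AV-HuBERT / Fairseq layout: the dataset root on the first line, then one tab-separated line per utterance with id, video path, audio path, video frame count, and audio sample count. The root, utterance list, and output file name here are hypothetical.

# Hedged sketch of writing an AV-HuBERT-style manifest (not the repo's notebook).
import cv2              # frame counts (opencv-python, pulled in by the AV-HuBERT requirements)
import soundfile as sf  # audio sample counts

root = "/path/to/lrs2"  # hypothetical dataset root
utterances = [          # hypothetical (id, relative video path, relative audio path) entries
    ("main/6330311066473698535/00011",
     "video/main/6330311066473698535/00011.mp4",
     "audio/main/6330311066473698535/00011.wav"),
]

with open("lrs2_test.tsv", "w") as f:
    f.write(root + "\n")  # first line: dataset root
    for utt_id, video, audio in utterances:
        n_frames = int(cv2.VideoCapture(f"{root}/{video}").get(cv2.CAP_PROP_FRAME_COUNT))
        n_samples = sf.info(f"{root}/{audio}").frames
        f.write(f"{utt_id}\t{video}\t{audio}\t{n_frames}\t{n_samples}\n")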

Pre-trained Models

We release our pre-trained models (GPUs = GPUs used for training).

Audio-only Whisper (fine-tuned on LRS3 / MuAViC)

| Mod. | Size | VoxCeleb2 | Parameters | En ASR | En-X ST | GPUs | Download Link |
|------|------|-----------|------------|--------|---------|------|---------------|
| A | Large-V2 | yes | 1,550M | yes | no | 1x A6000, 48GB | noisy: whisper_en_large_vc2_noisy / clean: whisper_en_large_vc2_clean |
| A | Large-V2 | no | 1,550M | yes | no | 1x A6000, 48GB | whisper_en_large |
| A | Large-V2 | no | 1,550M | yes | yes | 4x A6000, 48GB | whisper_en-x_large |
| A | LRS2-Medium | no | 769M | yes | no | 1x A6000, 48GB | whisper_lrs2_medium |
| A | Medium | no | 769M | yes | yes | 4x A5000, 24GB | whisper_en-x_medium |
| A | Small | no | 244M | yes | yes | 4x A5000, 24GB | whisper_en-x_small |

Audio-visual Whisper-Flamingo

| Mod. | Size | VoxCeleb2 | Parameters | En ASR | En-X ST | GPUs | Download Link |
|------|------|-----------|------------|--------|---------|------|---------------|
| AV | Large-V2 | yes | 2,497M | yes | no | 1x A6000, 48GB | noisy: whisper-flamingo_en_large_vc2_noisy / clean: whisper-flamingo_en_large_vc2_clean |
| AV | Large-V2 | no | 2,497M | yes | no | 1x A6000, 48GB | whisper-flamingo_en_large |
| AV | Large-V2 | no | 2,497M | yes | yes | 4x A6000, 48GB | whisper-flamingo_en-x_large |
| AV | LRS2-Medium | no | 1,390M | yes | no | 1x A6000, 48GB | whisper-flamingo_lrs2_medium |
| AV | Medium | no | 1,390M | yes | yes | 4x A6000, 48GB | whisper-flamingo_en-x_medium |
| AV | Small | no | 651M | yes | yes | 4x A5000, 24GB | whisper-flamingo_en-x_small |

Decoding Script

The script whisper_decode_video.py is used for decoding both audio-only Whisper models and audio-visual Whisper-Flamingo models. We also provide SLURM scripts to run decoding in parallel; see the SLURM section below for details.

Audio-Only Decoding

Download our audio-only Whisper model fine-tuned for En-X translation.

mkdir models
wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whisper_en-x_small.pt -P models

Decode an audio-only model (see whisper_decode_video.py for argument details):

LRS2 ASR decoding with the whisper_lrs2_medium checkpoint (see the pre-trained models table above; adjust --noise-snr as desired):

python -u whisper_decode_video.py --lang lrs2 \
                                --model-type medium \
                                --noise-snr 1000 \
                                --noise-fn noise/babble/lrs3/test.tsv \
                                --beam-size 1 \
                                --modalities asr \
                                --fp16 1 \
                                --checkpoint-path models/whisper_lrs2_medium.pt \
                                --decode-path decode/
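
Decoding writes hypotheses and references under --decode-path; the exact file names depend on whisper_decode_video.py, so the paths below are hypothetical. As a hedged illustration, WER can be recomputed with the evaluate package installed earlier:

# Hedged example: re-score a decoding run with the `evaluate` WER metric
# (backed by jiwer; pip install jiwer if it is missing).
import evaluate

wer_metric = evaluate.load("wer")
with open("decode/hypo.txt") as h, open("decode/ref.txt") as r:  # hypothetical file names
    hyps = [line.strip() for line in h]
    refs = [line.strip() for line in r]
print("WER:", wer_metric.compute(predictions=hyps, references=refs))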

Audio-Visual Decoding

Download our audio-visual Whisper-Flamingo model fine-tuned for En-X translation. Note: the AV-HuBERT weights must be downloaded and are used by Fairseq to load the architecture.

mkdir models
wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/whisper-flamingo_en-x_small.pt -P models
wget https://data.csail.mit.edu/public-release-sls/whisper-flamingo/models/large_noise_pt_noise_ft_433h_only_weights.pt -P models

Decode an audio-visual model:

python -u whisper_decode_video.py --lang en \
                                --model-type small \
                                --noise-snr 0 \
                                --noise-fn noise/babble/muavic/test.tsv \
                                --beam-size 1 \
                                --modalities avsr \
                                --use_av_hubert_encoder 1 \
                                --av_fusion separate \
                                --fp16 1 \
                                --checkpoint-path models/whisper-flamingo_en-x_small.pt \
                                --decode-path decode/ \
                                --av-hubert-path av_hubert/avhubert/ \
                                --av-hubert-ckpt models/large_noise_pt_noise_ft_433h_only_weights.pt

LRS2 AVSR decoding with the whisper-flamingo_lrs2_medium checkpoint (see the pre-trained models table above; adjust --noise-snr as desired):

python -u whisper_decode_video.py --lang lrs2 \
                                --model-type medium \
                                --noise-snr 1000 \
                                --noise-fn noise/babble/lrs3/test.tsv \
                                --beam-size 1 \
                                --modalities avsr \
                                --use_av_hubert_encoder 1 \
                                --av_fusion separate \
                                --fp16 1 \
                                --checkpoint-path models/whisper-flamingo_lrs2_medium.pt \
                                --decode-path decode/ \
                                --av-hubert-path av_hubert/avhubert/ \
                                --av-hubert-ckpt models/large_noise_pt_noise_ft_433h_only_weights.pt

Decoding Script in Parallel with SLURM

We provide slurm/whisper_decode_video_slurm_wrapper.sh, which submits decoding jobs to SLURM for a given checkpoint to test all En-X languages in both clean and noisy conditions. Please modify slurm/whisper_decode_video_slurm.sh to match your SLURM environment.

After submitting all jobs with source slurm/whisper_decode_video_slurm_wrapper.sh, use slurm/check_results.ipynb to print the results of all decoding runs. It loads the decoding WER / BLEU scores and prints them in a convenient table.

Training

Step 1: Fine-tune audio-only Whisper for En-X translation on MuAViC

First, in config/audio/audio_en-x_large.yaml, replace noise_fn: '/data/sls/scratch/roudi/datasets/musan/tsv/all/train.tsv' with the path to your training noise. Command:

python -u whisper_ft_muavic.py config/audio/audio_en-x_large.yaml
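
If you prefer to script that config edit instead of doing it by hand, a small helper like the following (not part of the repo; assumes PyYAML is available and that noise_fn is a top-level key) can point noise_fn at your own noise manifest:

# Hedged helper: rewrite noise_fn in the training config.
# Note: re-dumping the yaml drops any comments it contained.
import yaml

cfg_path = "config/audio/audio_en-x_large.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg["noise_fn"] = "/path/to/your/noise/train.tsv"  # your training-noise manifest

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f)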

We also provide a slurm script in slurm/train_audio_4gpu.sh. It took about 2-3 days to fine-tune Whisper Large-V2 on our GPUs. The medium and small models take less time to train.

Step 2: Train audio-visual Whisper-Flamingo with gated cross attention

Once the audio model is fine-tuned, we freeze the weights and insert the gated cross-attention layers to train the audio-visual Whisper-Flamingo. Command:

python -u whisper_ft_muavic_video.py config/audio-visual/av_en-x_large.yaml

We also provide a slurm script in slurm/train_video_4gpu.sh. Training Whisper-Flamingo is faster since the gated cross-attention layers are the only trainable layers. It took about 1 day to train Whisper-Flamingo Large on our GPUs (not including the time to fine-tune the audio model in the first step).
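
For intuition about what gets inserted: Flamingo-style gated cross attention wraps a cross-attention layer in a tanh gate initialized at zero, so the frozen audio model behaves exactly as before at the start of training and gradually learns to admit visual features. A minimal PyTorch sketch of the idea (illustrative only, not the exact layer used in this repo):

# Minimal sketch of Flamingo-style gated cross attention.
# The gate starts at 0, so tanh(gate) = 0 and the block is initially an identity.
import torch
import torch.nn as nn

class GatedCrossAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, visual: torch.Tensor) -> torch.Tensor:
        # x: (batch, text_len, dim) decoder states; visual: (batch, video_len, dim) visual features
        attn_out, _ = self.attn(self.norm(x), visual, visual)
        return x + torch.tanh(self.gate) * attn_out

# Example with random tensors of hypothetical shapes:
x = torch.randn(2, 10, 512)
visual = torch.randn(2, 50, 512)
print(GatedCrossAttention(dim=512)(x, visual).shape)  # torch.Size([2, 10, 512])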

Training progress

Model weights will be saved in models/checkpoint. Tensorboard can be opened to monitor several metrics.

cd slurm
tensorboard --logdir .  --port 6008

Training notes

Acknowledgments

This codebase is based on the following repos: Whisper Fine-Tuning Demo, Whisper, AV-HuBERT, MuAViC, ESPnet, AutoAVSR, Flamingo-pytorch.

License

Our work is licensed under BSD-3. However, please check the licenses of the works we build on, including AV-HuBERT.

Citation

@inproceedings{rouditchenko24_interspeech,
  title     = {Whisper-Flamingo: Integrating Visual Features into Whisper for Audio-Visual Speech Recognition and Translation},
  author    = {Andrew Rouditchenko and Yuan Gong and Samuel Thomas and Leonid Karlinsky and Hilde Kuehne and Rogerio Feris and James Glass},
  year      = {2024},
  booktitle = {Interspeech 2024},
  pages     = {2420--2424},
  doi       = {10.21437/Interspeech.2024-322},
  issn      = {2958-1796},
}