sanchit-gandhi / whisper-jax

JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
Apache License 2.0

Minimizing (repetitive) hallucinations and improving accuracy #148

Open: gl-jkeys opened this issue 8 months ago

gl-jkeys commented 8 months ago

Hello! This is a great repository, thank you very much @sanchit-gandhi!

We would like to use this repository in our system, but quite a few of our word-error rate (WER) regression tests fail when we use it as the backend for our transcription system with the settings closest to those we use with faster-whisper. That is to say, for the same model size and beam size, and with (non-word-level) timestamps enabled, whisper-jax produces a higher word-error rate than faster-whisper on a significant proportion of our test files.

One of the main things that significantly affects WER, in my testing with longer English-only files, is repetitive hallucination. whisper-jax often seems to get stuck in repetitive hallucination loops before eventually breaking out and transcribing the next actual speech in the file.
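One common heuristic for flagging this kind of repetitive output is the text's compression ratio: OpenAI's reference Whisper transcription loop (and faster-whisper) treats a segment whose zlib compression ratio exceeds 2.4 as a failed decode and retries it at a higher temperature. Below is a minimal sketch of the check itself; the function names are hypothetical, and nothing in this thread establishes whether whisper-jax applies a similar check.

```python
import zlib


def compression_ratio(text: str) -> float:
    """Raw UTF-8 size divided by zlib-compressed size.

    Repetitive text (e.g. the same phrase looped many times) compresses very
    well, so its ratio is high; ordinary speech transcripts stay low.
    """
    text_bytes = text.encode("utf-8")
    if not text_bytes:
        return 0.0
    return len(text_bytes) / len(zlib.compress(text_bytes))


def looks_repetitive(text: str, threshold: float = 2.4) -> bool:
    """Flag a segment as a likely repetition loop (hypothetical helper).

    2.4 is the default threshold used by OpenAI's Whisper transcribe();
    treat it as a starting point to tune against your own WER tests.
    """
    return compression_ratio(text) > threshold


print(looks_repetitive("Thank you. " * 30))
print(looks_repetitive("The quick brown fox jumps over the lazy dog."))
```

In a full pipeline, a flagged segment would typically be re-decoded (for example with sampling at a higher temperature) rather than simply dropped.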

I think that whisper-jax is missing a few features offered by faster-whisper that might be contributing to this disparity in WER, primarily (1) word-level timestamps and (2) voice-activity detection filtering (as in faster-whisper's implementation).

With all of the above in mind, I have a few questions regarding minimizing hallucinations and overall improving the accuracy of the inferences.

  1. Do you have any suggestions for how to go about porting word-level timestamps to this repository? I would be willing to take a crack at it, but might need some guidance on porting the word-timestamp algorithm from the Hugging Face Whisper implementation. Related: #37
  2. Would voice activity detection (VAD) filtering be an appropriate addition to this repository? We have already adapted faster-whisper's algorithm to clip audio with VAD before submitting it to whisper-jax in our system. We can submit a PR to add the functionality, if it aligns with the goals of this repository. In our testing, VAD filtering significantly reduces hallucinations caused by silence and/or non-speech sound or noise. (It also offers a nice speed-up for files with a lot of silence.)
  3. Are there any other missing features or disparities that might be affecting Word-Error Rate that can be ported here?
  4. Are there optimal settings for reducing repetitive hallucinations?
  5. Is it possible to enable condition_on_previous_text=False with this repository?
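The VAD clipping described in question 2 has two halves: keep only the speech regions, then map timestamps predicted on the clipped audio back to the original timeline. A minimal sketch of both halves, assuming the speech segments have already been produced by a VAD (faster-whisper uses Silero VAD) as `(start, end)` sample pairs at 16 kHz. `clip_to_speech` and `restore_time` are hypothetical names, not faster-whisper's API.

```python
import numpy as np


def clip_to_speech(audio: np.ndarray, segments: list) -> np.ndarray:
    """Concatenate the speech regions (start, end) of `audio`, in samples."""
    if not segments:
        return np.zeros(0, dtype=audio.dtype)
    return np.concatenate([audio[start:end] for start, end in segments])


def restore_time(t_clipped: float, segments: list, sampling_rate: int = 16000) -> float:
    """Map a timestamp (seconds) in the clipped audio back to the original audio."""
    if not segments:
        return t_clipped
    sample = t_clipped * sampling_rate
    offset = 0  # clipped-audio samples contributed by earlier segments
    for start, end in segments:
        length = end - start
        if sample <= offset + length:
            # The timestamp falls inside this segment: shift by its original start.
            return (start + (sample - offset)) / sampling_rate
        offset += length
    # Past the end of the clipped audio: clamp to the end of the last segment.
    return segments[-1][1] / sampling_rate


audio = np.arange(48000, dtype=np.float32)  # 3 s of dummy audio at 16 kHz
segments = [(16000, 24000), (32000, 40000)]  # speech at 1.0-1.5 s and 2.0-2.5 s
clipped = clip_to_speech(audio, segments)   # 1 s of audio sent to Whisper
print(restore_time(0.75, segments))         # 0.75 s in the clip is 2.25 s originally
```

The same mapping would be applied to the start and end of every segment the model returns, so the final timestamps still refer to the unclipped file.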

Thank you again for this fantastic repository!

sanchit-gandhi commented 6 months ago

Hey @gl-jkeys! Glad to hear this has been a useful resource, and apologies for the late reply.

  1. Word-level timestamps are possible through the dynamic time warping (DTW) algorithm, but the algorithm first needs to be ported from PyTorch to Flax (reference implementation: https://github.com/huggingface/transformers/blob/6af3ce7757e87e7e3380b0405bd0757805d41182/src/transformers/models/whisper/modeling_whisper.py#L2571-L2576). I think this should be quite straightforward since the DTW algorithm is written in numpy, so we'd just need to convert it all to jax.numpy (so as not to have a CPU operation in the modelling code, which would require a sync point when we jit the forward pass).
  2. Very cool! What's the VAD algorithm you've used? Is it based only on the outputs of Whisper, or on an additional model? This repository is meant to be as streamlined as possible, so if it relies only on Whisper, it could be a nice addition.
  3. Beam search will help, at the expense of slower inference. You can see code for this here: https://github.com/huggingface/distil-whisper/blob/main/training/flax/distil_whisper/pipeline.py
  4. You can try a lower chunk length, e.g. 25 seconds, since this sometimes helps reduce hallucinations, especially when used with distilled models (see https://github.com/sanchit-gandhi/whisper-jax/blob/9c50a6ee5f30f6429ad46dc748603296dfb3484b/whisper_jax/pipeline.py#L424). I would try this modification first, since it's super easy to set and observe the results. You can try a few different values for the chunk length, e.g. [25, 27.5, 30], and see which works best.
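Point 1 suggests converting the numpy DTW from the Transformers Whisper code to jax.numpy so it can live inside the jitted forward pass. Below is a minimal sketch of what that might look like, assuming a float `(N, M)` cost matrix of text tokens by audio frames (in the Transformers code this matrix is derived from negated cross-attention weights). The function names are hypothetical, not an existing whisper-jax or Transformers API. Two design choices keep it jit-friendly: nested `lax.scan` calls replace the Python double loop, and the backtrace writes into a fixed-length `(N + M,)` buffer padded with -1 so output shapes stay static. Tie-breaking differs from the numpy reference, so paths can differ on exact ties.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dtw_traces(x):
    """Fill the DTW table for a float (N, M) cost matrix x.

    Returns an (N, M) trace table: trace[i-1, j-1] records which predecessor
    cell (i, j) of the (N+1, M+1) accumulated-cost table came from:
    0 = diagonal, 1 = up (previous text token), 2 = left (previous frame).
    """
    n, m = x.shape
    inf = jnp.array(jnp.inf, dtype=x.dtype)
    first_row = jnp.full((m + 1,), jnp.inf, dtype=x.dtype).at[0].set(0.0)

    def row_step(prev_row, x_row):
        # Cell (i, j) depends on (i, j-1), so columns are scanned sequentially too.
        def col_step(left, inputs):
            diag, up, c = inputs
            candidates = jnp.stack([diag, up, left])
            t = jnp.argmin(candidates)  # ties go to the first candidate
            val = c + candidates[t]
            return val, (val, t)

        _, (vals, trace_row) = lax.scan(col_step, inf, (prev_row[:-1], prev_row[1:], x_row))
        new_row = jnp.concatenate([inf[None], vals])  # column 0 stays at +inf
        return new_row, trace_row

    _, traces = lax.scan(row_step, first_row, x)
    return traces


def dtw_backtrace(traces):
    """Walk back from (N, M) to (0, 0) with a static-shape while_loop.

    Returns (text_idx, time_idx, length): the path in forward order occupies
    the last `length` entries of each (N + M,) array; the rest is -1 padding.
    """
    n, m = traces.shape
    max_len = n + m  # each step lowers i + j by at least 1
    # Boundary convention of the numpy reference: row 0 moves left, column 0 moves up.
    full = jnp.full((n + 1, m + 1), 2, dtype=traces.dtype)
    full = full.at[:, 0].set(1).at[1:, 1:].set(traces)

    def cond(state):
        i, j, _, _, _ = state
        return (i > 0) | (j > 0)

    def body(state):
        i, j, k, text_idx, time_idx = state
        text_idx = text_idx.at[max_len - 1 - k].set(i - 1)
        time_idx = time_idx.at[max_len - 1 - k].set(j - 1)
        t = full[i, j]
        i = jnp.where(t != 2, i - 1, i)  # diagonal or up
        j = jnp.where(t != 1, j - 1, j)  # diagonal or left
        return i, j, k + 1, text_idx, time_idx

    init = (jnp.int32(n), jnp.int32(m), jnp.int32(0),
            jnp.full((max_len,), -1, jnp.int32), jnp.full((max_len,), -1, jnp.int32))
    _, _, length, text_idx, time_idx = lax.while_loop(cond, body, init)
    return text_idx, time_idx, length


@jax.jit
def dynamic_time_warping(x):
    return dtw_backtrace(dtw_traces(x))


text_idx, time_idx, length = dynamic_time_warping(jnp.array([[0., 0., 1.], [1., 1., 0.]]))
# Valid path: text_idx[-length:], time_idx[-length:]
```

Note that the scans are still sequential over all N x M cells; a wavefront (anti-diagonal) formulation would expose more parallelism if this turns out to be a bottleneck.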
gl-jkeys commented 5 months ago

Hey @sanchit-gandhi, sorry for the extremely delayed response. I let it go over the holidays, and then forgot to come back to this message afterward. Apologies!

We decided not to use whisper-jax in our pipeline, largely because this repo is effectively archived. Thankfully, you released distil-whisper, and we are actively investigating dropping the distilled implementation into our pipeline! I respond point-by-point below, though, in case it helps anyone else who stumbles upon this repo and really wants to use TPUs.

  1. Perfect, thanks.
  2. We're using the exact same algorithm that faster-whisper uses, which relies on Silero VAD. It's an interesting idea to use a non-ensemble approach, or more specifically, use the same model to accomplish both tasks; I can't immediately think of how one would do that without doing two full inference passes, however. [Would you expect a slimmed down version of Whisper that does not infer the transcript to be significantly faster in the decoding phase?]
  3. Thank you.
  4. A data point: in our investigation of distil-whisper, we are using your recommended chunk length of 15s, and there are comparatively few hallucinations in that repo/implementation (using distil-medium.en).
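For anyone sweeping chunk lengths as suggested above (15, 25, 27.5, 30 s), it helps to see how the chunk length turns into overlapping windows. Below is a minimal sketch under stated assumptions: the stride defaults to one sixth of the chunk on each side and consecutive windows advance by `chunk_len - 2 * stride`, a rule borrowed from the Transformers ASR pipeline's chunking and not verified against whisper-jax's own code. `chunk_windows` is a hypothetical helper.

```python
def chunk_windows(num_samples, chunk_length_s, stride_length_s=None, sampling_rate=16000):
    """Overlapping (start, end) sample windows for chunked long-form transcription.

    Assumption: stride defaults to chunk_length_s / 6 on each side, so
    neighbouring windows overlap by 2 * stride samples.
    """
    if stride_length_s is None:
        stride_length_s = chunk_length_s / 6
    chunk_len = int(round(chunk_length_s * sampling_rate))
    stride = int(round(stride_length_s * sampling_rate))
    step = chunk_len - 2 * stride
    if step <= 0:
        raise ValueError("chunk length must exceed twice the stride")
    windows = []
    for start in range(0, num_samples, step):
        end = min(start + chunk_len, num_samples)
        windows.append((start, end))
        if end == num_samples:  # last window reaches the end of the audio
            break
    return windows


one_minute = 60 * 16000
for length_s in (15.0, 25.0, 27.5, 30.0):
    print(length_s, len(chunk_windows(one_minute, length_s)))
```

Shorter chunks produce more windows (and more boundaries to stitch), so each value in the sweep trades a little extra compute for whatever WER difference your regression tests show.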