Blair-Johnson / batch-whisper

Batch Support for OpenAI Whisper
MIT License

Bottleneck in mel / segment calculation? #6

Open · sbuser opened this issue 1 year ago

sbuser commented 1 year ago

I gave this a try today on a large number of audio files with batch sizes of 16 and then 32.

If I watch nvidia-smi dmon while the batch processing runs, I see very low GPU memory/core usage for large periods of time, then a spike in both, followed by a batch of text.

I speculated that there was a bottleneck in decoding the audio files into mels, but it looks to me like you do that at the start. I'm not familiar enough with what is calculated on and off the GPU and when. Is it possible there's a lot of work involved in decoding, e.g., an MP3, that could be threaded and that I'm missing somewhere? GPU utilization is only near maximum for maybe 5% of the time the program is running, which isn't my typical experience with ML resource utilization.

Blair-Johnson commented 1 year ago

You're absolutely right that there's an audio decoding bottleneck. Currently the audio files are decoded serially at the beginning of a batch with ffmpeg. This could and should be parallelized, as it can take a long time for long (or many) audio files.

The GPU usage is actually quite decent once the audio files are decoded and the system begins transcription, but that part of the pipeline can't start until the tensor arrays with the decoded audio are prepared. Other parts of the pipeline definitely warrant investigation, but I think this explains what you're observing.

Would you be interested in submitting a PR? The fix should be as simple as replacing the serial decoding of a batch of audio file-paths here with a multiprocessing implementation: https://github.com/Blair-Johnson/batch-whisper/blob/76b3f818fb94a7c18c37d2ffecf71d7c69182749/whisper/transcribe.py#L335
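Something like this should do it (an untested sketch, not the final PR; load_audio is upstream Whisper's ffmpeg-backed decoder, and a thread pool is enough here because the heavy lifting happens in the ffmpeg subprocess):

```python
from concurrent.futures import ThreadPoolExecutor

from whisper.audio import load_audio  # shells out to ffmpeg for each file


def load_audio_batch(audio_paths, max_workers=8):
    """Decode a batch of audio files concurrently instead of one at a time.

    Each load_audio call spends most of its time waiting on an ffmpeg
    subprocess, so a thread pool is enough to overlap the decoding work.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(load_audio, audio_paths))
```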

sbuser commented 1 year ago

I can probably do that for the calculation of mels, but that isn't where I'm seeing the bottleneck as I run transcribe - which was confusing to me. If that's the only place audio decoding is taking place, then something else is CPU bound that I'm not understanding.

I didn't do a good job of explaining. When run in verbose mode with a batch size of 32, text is output let's say every 30 seconds or so. nvidia-smi dmon during that 30 second window looks something like below. Then there will be a spike of GPU usage (the 99% here and 37% memory), and then some text will pop out.

Edit: my paste from nvidia-smi doesn't format well here, but basically a single GPU sits at 15% core and 0% memory usage for 30 seconds, then spikes to 99% core and ~40% memory usage for a single second, then text pops out.

Meanwhile I see a single 5 GHz CPU core pegged. So I'm trying to figure out what is CPU-bound for those 30 seconds and prevents the GPU from running at 99% core and 37% memory usage more of the time.

If it were just the initial calculation of mels, wouldn't I only see that spike for new files? This pattern happens in the middle of a group of 32 files. Yes, there is a slowdown while mels are calculated initially for new files, but these are long stalls in the middle of files that have already been decoded.

I don't really understand the model or what's being calculated where, but I thought you might and could point me in the right direction.

sbuser commented 1 year ago

Here's one slow portion, I think:

https://github.com/Blair-Johnson/batch-whisper/blob/76b3f818fb94a7c18c37d2ffecf71d7c69182749/whisper/decoding.py#L704-L707

As it loops through samples, this part is particularly slow, but it doesn't seem to be part of the model inference or to be running on the GPU as far as I can tell. I'm not familiar enough with where the matrix math gets done here.

I'll keep investigating, but this appears to be very slow relative to the surrounding work.

Blair-Johnson commented 1 year ago

Good catch, I think I found the issue. On this line... https://github.com/Blair-Johnson/batch-whisper/blob/76b3f818fb94a7c18c37d2ffecf71d7c69182749/whisper/decoding.py#L493 ...the logit filter lists are all initialized with the same list object. As a result, each new batch of tokens for each audio file was being filtered by every logit filter, introducing redundancy and potential accuracy issues. I'll post a PR with that fix and the ffmpeg multiprocessing modification. Let me know if it improves performance on your end.
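In case it helps anyone hitting the same thing, the underlying Python pitfall is the shared-mutable-object one; a toy example (not the actual decoding.py code):

```python
batch_size = 3

# Buggy pattern: [[]] * n creates n references to one and the same list,
# so a filter added for file 0 shows up for every file in the batch.
filters = [[]] * batch_size
filters[0].append("SuppressBlank")
print(filters)  # [['SuppressBlank'], ['SuppressBlank'], ['SuppressBlank']]

# Fix: a comprehension builds an independent list per audio file.
filters = [[] for _ in range(batch_size)]
filters[0].append("SuppressBlank")
print(filters)  # [['SuppressBlank'], [], []]
```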

*edit: I also noticed that the ApplyTimestampRules filter was much slower than the others; it might be worth taking a look at down the line. You'll likely still see a high load on one core handling the main thread with logit filtering and decoding (which seem to be the slowest parts of the pipeline apart from model inference). Using the large model with six >1hr podcasts on a V100, I'm getting 85-100% GPU usage with periodic drops to zero (which happen when new segments are fetched for model inference). I was getting 60-75% previously, so it looks like a good improvement.

sbuser commented 1 year ago

Ah, that probably also explains why I have a list of no_speech_prob values in the output, as opposed to the single probability I expected and get with a normal whisper run.

sbuser commented 1 year ago

I think the large disparity between our resource usage was because I was using model=base and you were using large. When I switch to large, I get much higher GPU utilization, which makes sense. ApplyTimestampRules is definitely the bottleneck with the lower-parameter models.

Your fix here did raise GPU utilization, but in the results I still have an array of no_speech_prob values corresponding to the current batch size (which changes as audio files of different lengths drop out of the batch).

I'll open another issue if I pin this down, but there's also something going on with multiple model.decode() runs for each decode_with_fallback() call toward the end of the batch, probably something to do with shifting file sizes. I can't tell yet whether there are missing transcriptions and it's just falling back, or what.

Blair-Johnson commented 1 year ago

I updated the PR to address the no_speech_prob issue; let me know if you run into any problems. As for the multiple model.decode() executions, could you verify whether they are being triggered by temperature fallback? The loop checks whether the compression ratio of the generated text is too high, to avoid repetitive transcriptions, which can happen sometimes. If it's being triggered unnecessarily regardless of the threshold we set, then that's an issue.
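For reference, the fallback loop does roughly the following (a simplified sketch of upstream Whisper's decode_with_fallback for a single segment, using upstream's default thresholds; not the exact batch-whisper code):

```python
from whisper import DecodingOptions


def decode_with_fallback_sketch(model, segment,
                                temperatures=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
                                compression_ratio_threshold=2.4,
                                logprob_threshold=-1.0):
    # Retry decoding at progressively higher temperatures whenever the result
    # looks degenerate: overly repetitive text (high gzip compression ratio)
    # or low average log-probability.
    result = None
    for t in temperatures:
        result = model.decode(segment, DecodingOptions(temperature=t))
        needs_fallback = (
            (compression_ratio_threshold is not None
             and result.compression_ratio > compression_ratio_threshold)
            or (logprob_threshold is not None
                and result.avg_logprob < logprob_threshold)
        )
        if not needs_fallback:
            break
    return result
```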