snakers4 / silero-vad

Silero VAD: pre-trained enterprise-grade Voice Activity Detector
MIT License

Feature request - DO NOT disable PyTorch gradient globally when using PyTorch JIT model #460

Closed · gau-nernst closed this issue 2 weeks ago

gau-nernst commented 2 weeks ago

🚀 Feature

Thank you for open sourcing this awesome model. However, disabling PyTorch gradients globally is very intrusive.

https://github.com/snakers4/silero-vad/blob/82342b8a4ce695d013c45bee429bd7df455f9849/utils_vad.py#L159

May I suggest removing this line and instead decorating get_speech_timestamps() with @torch.no_grad()? (And perhaps VADIterator.__call__() as well.)
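Roughly, the change I have in mind looks like this (a minimal sketch only, not a ready diff against utils_vad.py; the parameter names are approximate):

```python
import torch

# Sketch: remove the global torch.set_grad_enabled(False) and instead
# scope no-grad behavior to the public inference entry points.

@torch.no_grad()
def get_speech_timestamps(audio, model, sampling_rate=16000, **kwargs):
    ...  # existing implementation stays unchanged


class VADIterator:
    @torch.no_grad()
    def __call__(self, x, return_seconds=False):
        ...  # existing implementation stays unchanged
```

Gradients would then be disabled only for the duration of each VAD call, not process-wide.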

Motivation

I'm using Silero VAD as part of my training pipeline. Disabling PyTorch gradients globally messes up my training. Of course I could circumvent that by re-enabling gradients and calling Silero VAD inside a torch.no_grad() context manager, or by using the ONNX version, but it would be better not to have this behavior in the first place.
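Concretely, the workaround looks something like this (a minimal sketch; the hub entry point and file path are illustrative):

```python
import torch

# Load the JIT model; utils_vad.py currently flips gradients off globally here.
model, utils = torch.hub.load('snakers4/silero-vad', 'silero_vad')
(get_speech_timestamps, _, read_audio, *_) = utils

torch.set_grad_enabled(True)  # undo the global side effect before training code runs

wav = read_audio('example.wav', sampling_rate=16000)  # path is a placeholder
with torch.no_grad():  # gradients stay off only around the VAD call
    speech_ts = get_speech_timestamps(wav, model, sampling_rate=16000)
```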

snakers4 commented 2 weeks ago

Looks like a good feature. Can you submit a PR maybe?

snakers4 commented 2 weeks ago

Also, how exactly do you use it in the training loop?

gau-nernst commented 2 weeks ago

I was experimenting with training an audio classification model on the silent parts of the audio. So I use VAD to detect and remove speech on the fly inside the dataloader, as a data pre-processing step.
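For reference, a stripped-down sketch of that step (the class and helper wiring here are illustrative, not my exact pipeline):

```python
import torch
from torch.utils.data import Dataset

class SilenceDataset(Dataset):
    """Hypothetical dataset that drops speech regions and keeps the rest."""

    def __init__(self, paths, vad_model, get_speech_timestamps, read_audio, sr=16000):
        self.paths = paths
        self.vad_model = vad_model
        self.get_speech_timestamps = get_speech_timestamps
        self.read_audio = read_audio
        self.sr = sr

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        wav = self.read_audio(self.paths[idx], sampling_rate=self.sr)
        with torch.no_grad():  # keep the VAD call out of autograd
            speech = self.get_speech_timestamps(wav, self.vad_model, sampling_rate=self.sr)
        # Mask out detected speech (timestamps are in samples by default).
        keep = torch.ones(len(wav), dtype=torch.bool)
        for ts in speech:
            keep[ts['start']:ts['end']] = False
        return wav[keep]
```

This is exactly why the global torch.set_grad_enabled(False) is a problem: it silently disables gradients for the classifier being trained in the same process.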