google / uis-rnn

This is the library for the Unbounded Interleaved-State Recurrent Neural Network (UIS-RNN) algorithm, corresponding to the paper Fully Supervised Speaker Diarization.
https://arxiv.org/abs/1810.04719
Apache License 2.0

Batch prediction? - or allow prediction using multiprocessing #32

Closed: hbredin closed this issue 4 years ago

hbredin commented 5 years ago

Describe the question

The documentation states that prediction can only be applied to one sequence at a time.

On my machine (with a GPU), it takes more than 10 seconds to process one sequence of 100 samples. It would be nice to support batch prediction to make processing large collections of sequences faster. Beam search probably makes it impossible, though.

Anyway, thanks for open-sourcing this implementation. This is really appreciated!

My background

Have I read the README.md file?

Have I searched for similar questions from closed issues?

Have I tried to find the answers in the paper Fully Supervised Speaker Diarization?

Have I tried to find the answers in the reference Speaker Diarization with LSTM?

Have I tried to find the answers in the reference Generalized End-to-End Loss for Speaker Verification?

wq2012 commented 5 years ago

Thanks for your interest in our work!

Indeed, the decoding algorithm is nearly impossible to parallelize, especially due to the beam search and look-aheads...

Maybe it's still possible, but we don't have the bandwidth to explore this direction.

My suggestion is to use Python's multiprocessing features to run prediction on multiple sequences in parallel.
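
Something like this rough sketch (not library code; the checkpoint path and data list are placeholders, and the uisrnn calls follow the README demo):

```python
# Rough sketch (not part of uis-rnn) of per-sequence prediction with a
# process pool. MODEL_PATH and test_sequences are placeholders; the
# deadlock caveats discussed later in this thread still apply.
import multiprocessing

import uisrnn

MODEL_PATH = 'saved_uisrnn_model'  # placeholder: path to a saved model

model_args, _, inference_args = uisrnn.parse_arguments()


def predict_one(test_sequence):
  # Each worker builds and loads its own model copy, since a live model
  # object may not pickle cleanly across process boundaries.
  model = uisrnn.UISRNN(model_args)
  model.load(MODEL_PATH)
  return model.predict(test_sequence, inference_args)


if __name__ == '__main__':
  test_sequences = [...]  # placeholder: list of (length, dim) numpy arrays
  with multiprocessing.Pool(processes=4) as pool:
    predicted_cluster_ids = pool.map(predict_one, test_sequences)
```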

wq2012 commented 5 years ago

Actually, you made a good point.

Let me add multiprocessing support within the predict() function, and allow users to configure how many processes to use.

hbredin commented 5 years ago

Awesome. Thanks!

wq2012 commented 5 years ago

I've tried for a while with multiprocessing and torch.multiprocessing, but got stuck with deadlocks.

PyTorch doesn't work well with multiprocessing: https://pytorch.org/docs/stable/notes/multiprocessing.html

I don't have other good solutions for now.

hbredin commented 5 years ago

Oh. Crap. Thanks for trying anyway!

chrisspen commented 4 years ago

Would this fork, which adds some support for multiprocessing, be relevant?

Also, is there any way to split files apart, run separate instances of UIS-RNN on each file in parallel, and then combine the segment labels back together? The main limitation of the algorithm is that, even with a GPU, it takes a very long time.

wq2012 commented 4 years ago

@DonkeyShot21 Did you try to benchmark the speed of using torch.multiprocessing vs not using it (on CPU and GPU)?

I tried on CPU. Using torch.multiprocessing is even slower (likely due to context switching).

DonkeyShot21 commented 4 years ago

I did not try on CPU. However, on GPU (a 1080 Ti with an 8-core CPU) it is around 4x faster with multiprocessing in my tests. Be aware that both multiprocessing and torch.multiprocessing are a bit buggy and can cause memory leaks if they are not set up well. Also, using mp.get_context('forkserver') in multiprocessing seems to increase performance quite significantly. BTW, I was not able to implement multiprocessing as a feature "inside" the UISRNN class, because of how multiprocessing works (it uses pickle).
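
The start-method tweak looks roughly like this (illustrative sketch; predict_one and test_sequences stand in for the per-sequence helper and data shown earlier in this thread, not uis-rnn APIs):

```python
# Illustrative only: use the 'forkserver' start method so workers are not
# plain forks of a process that may already hold CUDA state.
import multiprocessing as mp

if __name__ == '__main__':
  ctx = mp.get_context('forkserver')
  with ctx.Pool(processes=4) as pool:
    # predict_one / test_sequences: hypothetical helper and data from
    # the earlier sketch in this thread.
    predicted_cluster_ids = pool.map(predict_one, test_sequences)
```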

In general, I think the best way to speed up is using a bounded number of speakers, as reported in your paper.

wq2012 commented 4 years ago

@DonkeyShot21 OK, thanks for your information.

Currently I don't have machines to test whether this will make GPU prediction faster. I've submitted a uisrnn.parallel_predict() function in case people want to test it.
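
Usage should look roughly like this (a sketch; the exact signature may differ, so check the function's docstring in the repo):

```python
# Presumed usage of the new parallel_predict() helper; signature details
# may differ from this sketch. Placeholders as in the earlier example.
import uisrnn

model_args, _, inference_args = uisrnn.parse_arguments()
model = uisrnn.UISRNN(model_args)
model.load('saved_uisrnn_model')  # placeholder: path to a saved model

test_sequences = [...]  # placeholder: list of (length, dim) numpy arrays
predicted_cluster_ids = uisrnn.parallel_predict(
    model, test_sequences, inference_args, num_processes=4)
```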

DonkeyShot21 commented 4 years ago

Just one note: opening and closing the Pool inside uisrnn.parallel_predict() could cause a memory leak if it is called multiple times. It is not a problem in this specific implementation, but be aware that .close() does not work the way one would think. For more info, see this issue.
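
A safer pattern for repeated calls is to create one long-lived pool and reuse it across batches; a rough sketch (not library code, and the 'forkserver' choice is just the tweak mentioned above):

```python
# Illustrative sketch: reuse one long-lived pool across calls instead of
# opening and closing a fresh Pool on every prediction batch.
import multiprocessing as mp

_POOL = None

def get_pool(num_processes=4):
  # Lazily create a single pool on first use and hand it out thereafter.
  global _POOL
  if _POOL is None:
    _POOL = mp.get_context('forkserver').Pool(processes=num_processes)
  return _POOL
```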

chrisspen commented 4 years ago

I tested this on a GPU-equipped EC2 instance and also found it was 3-4x faster than the plain CPU version.

wq2012 commented 4 years ago

@DonkeyShot21 Thanks for letting me know. Indeed very weird behavior of PyTorch...

@chrisspen Cool! Glad to know that it worked.