desh2608 / gss

A simple package for Guided source separation (GSS)
MIT License

Batch processing #12

Closed desh2608 closed 1 year ago

desh2608 commented 1 year ago

Currently, each cut (segment) is processed one at a time. This is fine for long segments, which usually occupy the entire GPU memory, but may be wasteful for shorter segments. The conventional approach (e.g., in ASR) is mini-batch processing, padding several segments to the same length. However, we have to be careful doing that here, for two reasons:

  1. The CACGMM implementations are currently for 3-dimensional input (channels, time, frequency), and adding a batch dimension would require modifying a lot of internal implementation (which is done efficiently through einops).
  2. The CACGMM inference step computes sums over the whole time duration, so adding padding would require some kind of masking.
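To illustrate point 2, here is a minimal sketch (not the actual CACGMM code) of why a padded batch forces masking: any statistic summed over the time axis must exclude padded frames, or they bias the result. The function name and the batched `(batch, channels, time, frequency)` layout are assumptions for illustration.

```python
import numpy as np

def masked_time_mean(x, lengths):
    """Mean over time for a padded batch x of shape (batch, ch, time, freq).

    Hypothetical helper: padded frames beyond each segment's true length
    are zeroed out by the mask so they do not contribute to the sum.
    """
    b, c, t, f = x.shape
    # (batch, time) boolean mask of valid frames
    mask = np.arange(t)[None, :] < np.asarray(lengths)[:, None]
    mask = mask[:, None, :, None]  # broadcast to (batch, 1, time, 1)
    total = (x * mask).sum(axis=2)
    return total / np.asarray(lengths)[:, None, None]

x = np.ones((2, 4, 5, 3))
x[1, :, 3:, :] = 0.0  # frames past length 3 are padding for item 1
print(masked_time_mean(x, [5, 3]))  # all ones: padding does not bias the mean
```

Without the mask, the second item's mean would be 3/5 instead of 1, which is the kind of corruption the padding would introduce into the CACGMM sufficient statistics.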

For these reasons, it may be better to use a different kind of "batching". Instead of combining segments in parallel, we can combine them sequentially, but only if they are from the same recording and have the same target speaker. This ensures that we do not create a permutation problem in the mask estimation.
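The grouping step could look something like the following sketch (the segment dictionaries and field names are illustrative, not the package's actual cut representation): segments are grouped by `(recording, speaker)` and each group is concatenated along time into one "super-segment".

```python
from itertools import groupby

# Illustrative segment manifest; in practice these would be Lhotse cuts.
segments = [
    {"recording": "rec1", "speaker": "A", "start": 0.0, "end": 2.0},
    {"recording": "rec1", "speaker": "A", "start": 5.0, "end": 6.5},
    {"recording": "rec1", "speaker": "B", "start": 2.0, "end": 4.0},
    {"recording": "rec2", "speaker": "A", "start": 1.0, "end": 3.0},
]

key = lambda s: (s["recording"], s["speaker"])
# Each group becomes one super-segment processed in a single GSS pass;
# grouping by speaker avoids any cross-group permutation ambiguity.
super_segments = [
    {"recording": k[0], "speaker": k[1], "segments": list(g)}
    for k, g in groupby(sorted(segments, key=key), key=key)
]
for ss in super_segments:
    print(ss["recording"], ss["speaker"], len(ss["segments"]))
```

Because all segments in a group share one target speaker, the estimated masks keep a consistent speaker assignment across the concatenation.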

If we combine in this way, we can even remove individual contexts from each segment, and instead add the context to the combined "super-segment", which would further save compute/memory.
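A quick back-of-the-envelope sketch of the saving (the context length and segment durations are made-up numbers): per-segment processing pays the context cost once per segment, while the super-segment pays it once per group.

```python
# Illustrative values only: 10 s of context on each side, three short
# same-speaker segments from the same recording.
context = 10.0
seg_durations = [2.0, 1.5, 3.0]

# Per-segment: every segment carries its own left and right context.
per_segment = sum(d + 2 * context for d in seg_durations)
# Super-segment: one shared context around the concatenated segments.
super_segment = sum(seg_durations) + 2 * context
print(per_segment, super_segment)  # 66.5 vs 26.5 seconds of audio processed
```

The gap grows with the number of short segments per speaker, which is exactly the case where per-segment processing is most wasteful.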