Audio-WestlakeU / FullSubNet

PyTorch implementation of "FullSubNet: A Full-Band and Sub-Band Fusion Model for Real-Time Single-Channel Speech Enhancement."
https://fullsubnet.readthedocs.io/en/latest/
MIT License

Real-time streaming Fast FullSubNet (LSTMCell) #67

Open fronx opened 6 months ago

fronx commented 6 months ago

I'm trying to run Fast FullSubNet in a real-time audio streaming context.

I've successfully trained a model that seems to work reasonably well in a non-streaming context: https://github.com/fronx/FullSubNet/releases/tag/fast118

However, the latency of running it in that way is too high. I've tried turning down the hop length, but it just leads to choppy, unintelligible noise. So I looked around and apparently the structure of the code needs to be changed quite a bit for that to work?

I'm happy to execute the change and contribute it to this repo, but I might need a little bit more guidance so I don't go off track. I know how to program, but I'm still fairly new to audio ML.

Gathered instructions

For reference, to have everything in one place, here are instructions I gathered from older issues:

There are two things you need to do: change torch.nn.LSTM to torch.nn.LSTMCell and add a for-loop.

As you can see, for performance reasons the cumulative norm that I released is written in a compact style, i.e., it computes the statistical mean values for all frames of an utterance in advance. You should rewrite this function in a frame-wise style. The point is to ensure that the current frame is normalized using the statistical mean of all previous frames.
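To make the frame-wise idea concrete, here is a minimal sketch. It assumes the offline norm divides each frame by the mean of all time-frequency bins up to and including that frame; `StreamingCumulativeNorm`, `offline_cumulative_norm`, and `EPS` are hypothetical names written for this illustration, not functions from the repo. The streaming version only keeps a running sum and a bin count, and the check at the end compares it against a compact offline version on dummy data.

```python
import torch

EPS = 1e-10  # hypothetical small constant to avoid division by zero


class StreamingCumulativeNorm:
    """Normalizes each frame by the mean of all bins seen so far."""

    def __init__(self):
        self.running_sum = 0.0  # sum of all magnitudes seen so far
        self.count = 0          # number of time-frequency bins seen so far

    def __call__(self, frame: torch.Tensor) -> torch.Tensor:
        # frame: magnitude spectrum of a single frame, shape [F]
        self.running_sum = self.running_sum + frame.sum()
        self.count += frame.numel()
        cum_mean = self.running_sum / self.count
        return frame / (cum_mean + EPS)


def offline_cumulative_norm(mag: torch.Tensor) -> torch.Tensor:
    # mag: [F, T]; compact version that computes every cumulative mean at once
    num_freqs, num_frames = mag.shape
    cum_sum = torch.cumsum(mag.sum(dim=0), dim=0)          # [T]
    counts = torch.arange(1, num_frames + 1) * num_freqs   # [T]
    return mag / (cum_sum / counts + EPS)


torch.manual_seed(0)
mag = torch.rand(257, 20)  # dummy magnitude spectrogram, [F, T]

norm = StreamingCumulativeNorm()
streamed = torch.stack([norm(mag[:, t]) for t in range(mag.shape[1])], dim=1)
matches_offline = torch.allclose(streamed, offline_cumulative_norm(mag), atol=1e-5)
```

In a real streaming setup the `norm` object would have to live for the whole session, since its state is what replaces the utterance-level statistics.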

You may use a for-loop like here:

hx, cx = load(hidden_state)
rnn = nn.LSTMCell(dims)

output = []
for samples in (all_samples, step=hop_len):
    frame = fft(samples)
    frame = cum_norm(frame)
    hx, cx = rnn(frame, (hx, cx))
    output.append(ifft(hx))

overlapped_add(output)

You could check out here for the difference between the LSTMCell and LSTM.
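As a small sanity check of that relationship, the sketch below copies the parameters of a single-layer `nn.LSTM` into an `nn.LSTMCell` and compares running the whole sequence at once against stepping frame by frame. The sizes are hypothetical and the model is randomly initialized, so this only illustrates the mechanics; it does not use the FullSubNet weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, num_frames = 8, 16, 5  # hypothetical sizes

lstm = nn.LSTM(input_size, hidden_size, num_layers=1, batch_first=True)
cell = nn.LSTMCell(input_size, hidden_size)

# A single-layer nn.LSTM and an nn.LSTMCell store their parameters in the
# same layout, so the weights can be copied over directly.
with torch.no_grad():
    cell.weight_ih.copy_(lstm.weight_ih_l0)
    cell.weight_hh.copy_(lstm.weight_hh_l0)
    cell.bias_ih.copy_(lstm.bias_ih_l0)
    cell.bias_hh.copy_(lstm.bias_hh_l0)

x = torch.randn(1, num_frames, input_size)  # [B, T, input_size]

with torch.no_grad():
    full_out, _ = lstm(x)  # whole sequence in one call

    # Frame by frame, carrying the hidden and cell state between steps.
    hx = torch.zeros(1, hidden_size)
    cx = torch.zeros(1, hidden_size)
    step_outs = []
    for t in range(num_frames):
        hx, cx = cell(x[:, t], (hx, cx))
        step_outs.append(hx)
    stepped_out = torch.stack(step_outs, dim=1)  # [B, T, hidden_size]

same_output = torch.allclose(full_out, stepped_out, atol=1e-5)
```

An alternative that avoids touching the layers at all is to keep `nn.LSTM`, feed it one frame at a time (T = 1), and pass the returned `(h, c)` tuple back in on the next call; that also works for stacked LSTMs.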

Questions

  1. It looks like you suggest changing the model input from a magnitude spectrogram ([B, 1, F, T]) to an array of samples. Is that necessary? Wouldn't that require completely retraining the model from scratch?
  2. The pseudocode above doesn't mention MEL scaling, which is necessary for Fast FullSubNet. I assume that should also be applied per frame?
  3. Does the change from LSTM to looping over LSTMCell require retraining?

Thanks in advance for any hints you can provide. Would be nice if we could get this repo into usable shape for streaming inference in a way that's shareable with the world. 🤩

fronx commented 6 months ago

More details on the model latency:

How did you calculate your RTFs?

fronx commented 6 months ago

Good news: I updated my operating system to Sonoma 14.3.1 and that fixed it, without any further code changes. Now the processing time is consistently between 13ms and 15ms.
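Whether 13 to 15 ms per call is fast enough depends on how much audio each call covers. One common way to express this is the real-time factor (RTF): processing time divided by the duration of the audio processed, where values below 1 mean the model keeps up. The sketch below is only an illustration; `real_time_factor` and `fake_inference` are hypothetical helpers, and the hop length and sample rate are assumed values, not numbers from this thread.

```python
import time


def real_time_factor(process_fn, audio_seconds, repeats=10):
    """Average wall-clock time per call divided by the audio duration per call."""
    start = time.perf_counter()
    for _ in range(repeats):
        process_fn()
    elapsed_per_call = (time.perf_counter() - start) / repeats
    return elapsed_per_call / audio_seconds


def fake_inference():
    # Hypothetical stand-in for one model call on a single hop of audio.
    time.sleep(0.002)


hop_length, sample_rate = 256, 16000     # assumed values for illustration
hop_seconds = hop_length / sample_rate   # seconds of new audio per call
rtf = real_time_factor(fake_inference, hop_seconds)
print(f"RTF: {rtf:.3f}")
```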

fronx commented 6 months ago

For potential future reference, here's the torch.profiler output of a single inference run:

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
              model_inference         5.41%     780.000us       100.00%      14.412ms      14.412ms             1
                   aten::lstm         6.87%     990.000us        86.02%      12.397ms       2.479ms             5
                 aten::linear         1.02%     147.000us        39.57%       5.703ms     126.733us            45
                  aten::addmm        29.68%       4.277ms        35.56%       5.125ms     122.024us            42
               aten::sigmoid_        15.89%       2.290ms        15.89%       2.290ms      21.204us           108
                  aten::tanh_         8.48%       1.222ms         8.48%       1.222ms      33.944us            36
                   aten::tanh         7.88%       1.136ms         7.88%       1.136ms      31.556us            36
                  aten::copy_         6.45%     930.000us         6.45%     930.000us      17.222us            54
                   aten::add_         5.68%     818.000us         5.68%     818.000us      10.907us            75
                 aten::matmul         0.15%      22.000us         2.46%     354.000us      88.500us             4
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 14.412ms
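For anyone who wants to produce a similar table, a minimal sketch with `torch.profiler` looks roughly like this. The model here is a hypothetical stand-in (a small randomly initialized LSTM), not the Fast FullSubNet checkpoint, and the input shape is made up.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

# Hypothetical stand-in model; the real network would be loaded from a checkpoint.
model = nn.LSTM(input_size=64, hidden_size=128, num_layers=2, batch_first=True).eval()
frames = torch.randn(1, 10, 64)  # dummy input, [B, T, features]

with torch.no_grad(), profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("model_inference"):  # label that appears as a row in the table
        model(frames)

# Aggregate events by operator name and show the ten most expensive ones.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(table)
```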

And here's a pretty timeline view:

[Image: Screenshot 2024-02-16 at 22 14 17]