RoyChao19477 / SEMamba

This is the official implementation of the SEMamba paper. (Accepted to IEEE SLT 2024)

Inference speed on small inputs #4

iissme closed this issue 3 weeks ago

iissme commented 4 months ago

Hello, and thanks for the code! I tested the model on audio of various lengths and noticed that inference on short audio is slower than on long audio. The audio I intend to process is 0.5 to 1 second long, but the model processes it quite slowly. Zero-padding the input to 30 seconds speeds things up, but the inference time still does not drop below 2 seconds.
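The pad_audio helper used in the timing code is not shown in the thread. A minimal NumPy sketch of what zero-padding to a fixed duration might look like, assuming it right-pads with silence and leaves longer clips untouched. pad_to_duration is a hypothetical name, and this version works on the array returned by librosa.load, before it is turned into a tensor.

```python
import numpy as np


def pad_to_duration(wav: np.ndarray, seconds: float, sr: int = 16000) -> np.ndarray:
    """Right-pad a 1-D waveform with zeros so it lasts at least `seconds`.

    Hypothetical stand-in for the thread's pad_audio helper (assumed behaviour):
    clips that already reach the target length are returned unchanged.
    """
    target = int(round(seconds * sr))
    missing = target - wav.shape[0]
    if missing <= 0:
        return wav
    # Append `missing` zeros after the signal; nothing is added in front of it.
    return np.pad(wav, (0, missing), mode="constant")
```

For example, a 1-second clip at 16 kHz (16,000 samples) becomes 480,000 samples after padding to 30 seconds.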

Could you explain why this happens, and whether there is anything I can do to speed up inference on short audio?

Below are the timings for segments of different lengths taken from the beginning of the same audio file (step 1: normalization factor, 2: STFT, 3: model forward pass, 4: inverse STFT):

1s

INFO:__main__:1 - 0.032488 in seconds.
INFO:__main__:2 - 0.016094 in seconds.
INFO:__main__:3 - 5.410543 in seconds.
INFO:__main__:4 - 0.015493 in seconds.

5s

INFO:__main__:1 - 0.033495 in seconds.
INFO:__main__:2 - 0.014399 in seconds.
INFO:__main__:3 - 4.745659 in seconds.
INFO:__main__:4 - 0.015998 in seconds.

10s

INFO:__main__:1 - 0.034424 in seconds.
INFO:__main__:2 - 0.014369 in seconds.
INFO:__main__:3 - 3.391307 in seconds.
INFO:__main__:4 - 0.017608 in seconds.

20s

INFO:__main__:1 - 0.033065 in seconds.
INFO:__main__:2 - 0.015286 in seconds.
INFO:__main__:3 - 2.050973 in seconds.
INFO:__main__:4 - 0.015754 in seconds.

30s

INFO:__main__:1 - 0.033277 in seconds.
INFO:__main__:2 - 0.014498 in seconds.
INFO:__main__:3 - 2.001130 in seconds.
INFO:__main__:4 - 0.016316 in seconds.
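
Code used for these timings:

import logging
import time

import librosa
import torch

# The default basicConfig format gives lines like "INFO:__main__:1 - 0.032488 in seconds."
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# model, device, sampling_rate, n_fft, hop_size, win_size, compress_factor,
# mag_phase_stft, mag_phase_istft and pad_audio are defined elsewhere (not shown).
with torch.no_grad():
    noisy_wav, _ = librosa.load('1.wav', sr=sampling_rate)
    noisy_wav = torch.FloatTensor(noisy_wav).to(device)
    noisy_wav = pad_audio(noisy_wav, 30, sampling_rate)  # zero-pad to 30 seconds
    torch.cuda.synchronize(device)  # let the copy and padding finish before timing starts

    # 1: normalization factor (1 / RMS of the waveform)
    before_time = time.perf_counter()
    norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(device)
    torch.cuda.synchronize(device)
    logger.info('1 - %f in seconds.', time.perf_counter() - before_time)

    # 2: normalize, then STFT into magnitude, phase and complex spectra
    before_time = time.perf_counter()
    noisy_wav = (noisy_wav * norm_factor).unsqueeze(0)
    noisy_amp, noisy_pha, noisy_com = mag_phase_stft(noisy_wav, n_fft, hop_size, win_size, compress_factor)
    torch.cuda.synchronize(device)
    logger.info('2 - %f in seconds.', time.perf_counter() - before_time)

    # 3: SEMamba forward pass
    before_time = time.perf_counter()
    amp_g, pha_g, com_g = model(noisy_amp, noisy_pha)
    torch.cuda.synchronize(device)
    logger.info('3 - %f in seconds.', time.perf_counter() - before_time)

    # 4: inverse STFT back to a waveform, then undo the normalization
    before_time = time.perf_counter()
    audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_factor)
    audio_g = audio_g / norm_factor
    torch.cuda.synchronize(device)
    logger.info('4 - %f in seconds.', time.perf_counter() - before_time)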
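The snippet above times a single call of each step. As general GPU benchmarking practice (not something the thread confirms as the cause here), the first calls can include one-off costs such as lazy CUDA initialization or just-in-time kernel compilation (Triton compiles kernels on first use), so it is common to run a few warm-up iterations and report the median of several synchronized runs. A minimal, framework-agnostic sketch of that pattern; benchmark is a hypothetical helper, and sync is assumed to be any callable that blocks until queued GPU work is done, for example torch.cuda.synchronize.

```python
import statistics
import time
from typing import Any, Callable, Optional


def benchmark(fn: Callable[..., Any], *args: Any, warmup: int = 3, iters: int = 10,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Time fn(*args) after a few warm-up calls; returns summary statistics in seconds.

    `sync` should block until all queued device work has finished
    (e.g. torch.cuda.synchronize); without it, asynchronous GPU launches
    would make wall-clock timings meaningless.
    """
    if iters < 1:
        raise ValueError("iters must be at least 1")

    def wait() -> None:
        if sync is not None:
            sync()

    # Warm-up calls absorb one-off costs (lazy initialisation, JIT compilation,
    # autotuning) so they do not inflate the measured runs.
    for _ in range(warmup):
        fn(*args)
    wait()

    times = []
    for _ in range(iters):
        wait()  # start each run from an idle device
        start = time.perf_counter()
        fn(*args)
        wait()  # include the device work launched by fn
        times.append(time.perf_counter() - start)
    return {"median_s": statistics.median(times), "min_s": min(times), "runs": len(times)}
```

With the variables from the code above, and inside torch.no_grad(), the model step could be timed as benchmark(model, noisy_amp, noisy_pha, sync=torch.cuda.synchronize).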
RoyChao19477 commented 4 months ago

Hi iissme,

It’s quite unusual for short audio to be processed slower than longer ones. I replicated your timing measurements using a single RTX 4090 GPU, and here are my results:

0.5 Seconds:

INFO - 1 - 0.000331 in seconds.
INFO - 2 - 0.000562 in seconds.
INFO - 3 - 0.014941 in seconds.
INFO - 4 - 0.000750 in seconds.

1 Second:

INFO - 1 - 0.000222 in seconds.
INFO - 2 - 0.000513 in seconds.
INFO - 3 - 0.013630 in seconds.
INFO - 4 - 0.000649 in seconds.

10 Seconds:

INFO - 1 - 0.000251 in seconds.
INFO - 2 - 0.000518 in seconds.
INFO - 3 - 0.108044 in seconds.
INFO - 4 - 0.000846 in seconds.

100 Seconds:

INFO - 1 - 0.000486 in seconds.
INFO - 2 - 0.001104 in seconds.
INFO - 3 - 1.601276 in seconds.
INFO - 4 - 0.000918 in seconds.
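One way to compare these numbers across lengths is the real-time factor (processing time divided by audio duration; below 1 means faster than real time). A small sketch computing it for the model step (log line 3) of the RTX 4090 results above; the RTF metric is an addition here, not something reported in the thread.

```python
# Model-step timings (log line "3") from the RTX 4090 run above, in seconds,
# keyed by input duration in seconds.
model_step_s = {0.5: 0.014941, 1: 0.013630, 10: 0.108044, 100: 1.601276}


def real_time_factor(processing_s: float, audio_s: float) -> float:
    """Processing time per second of audio; values below 1 are faster than real time."""
    return processing_s / audio_s


rtf = {dur: real_time_factor(t, dur) for dur, t in model_step_s.items()}
for dur, value in rtf.items():
    print(f"{dur:>5} s audio: RTF = {value:.4f}")
```

By this measure, all four lengths run far faster than real time on that GPU.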

Code used for these measurements:

import os
import time

import librosa
import numpy as np
import torch

# args, fname, model, device, sampling_rate, n_fft, hop_size, win_size, compress_factor,
# mag_phase_stft, mag_phase_istft and logger come from the surrounding inference script (not shown).
with torch.no_grad():
    noisy_wav, _ = librosa.load(os.path.join(args.input_folder, fname), sr=sampling_rate)
    noisy_wav = noisy_wav[:16000]        # first second (16,000 samples at 16 kHz)
    noisy_wav = np.tile(noisy_wav, 100)  # repeat N times (N = 100 here, i.e. 100 seconds)

    noisy_wav = torch.FloatTensor(noisy_wav).to(device)

    # 1: normalization factor (1 / RMS of the waveform)
    before_time = time.perf_counter()
    norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(device)
    torch.cuda.synchronize(device)
    logger.info('1 - %f in seconds.', time.perf_counter() - before_time)

    # 2: normalize, then STFT into magnitude, phase and complex spectra
    before_time = time.perf_counter()
    noisy_wav = (noisy_wav * norm_factor).unsqueeze(0)
    noisy_amp, noisy_pha, noisy_com = mag_phase_stft(noisy_wav, n_fft, hop_size, win_size, compress_factor)
    torch.cuda.synchronize(device)
    logger.info('2 - %f in seconds.', time.perf_counter() - before_time)

    # 3: SEMamba forward pass
    before_time = time.perf_counter()
    amp_g, pha_g, com_g = model(noisy_amp, noisy_pha)
    torch.cuda.synchronize(device)
    logger.info('3 - %f in seconds.', time.perf_counter() - before_time)

    # 4: inverse STFT back to a waveform, then undo the normalization
    before_time = time.perf_counter()
    audio_g = mag_phase_istft(amp_g, pha_g, n_fft, hop_size, win_size, compress_factor)
    audio_g = audio_g / norm_factor
    torch.cuda.synchronize(device)
    logger.info('4 - %f in seconds.', time.perf_counter() - before_time)
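Roy's snippet builds each test length by tiling the first second of a file with np.tile, which only yields whole numbers of seconds; the 0.5-second case needs a slice instead. A small NumPy sketch covering both, assuming 16 kHz audio; tile_to_duration is a hypothetical helper, not part of SEMamba.

```python
import numpy as np


def tile_to_duration(clip: np.ndarray, seconds: float, sr: int = 16000) -> np.ndarray:
    """Repeat a 1-D clip end to end until it lasts `seconds`, cutting the last copy short.

    Hypothetical helper: for a one-second clip and a whole-second target n,
    the result matches np.tile(clip, n).
    """
    if clip.size == 0:
        raise ValueError("clip must contain at least one sample")
    n_samples = int(round(seconds * sr))
    # np.resize fills the requested length by cycling through the input samples.
    return np.resize(clip, n_samples)
```

For example, tile_to_duration(noisy_wav[:16000], 0.5) keeps the first 8,000 samples, while tile_to_duration(noisy_wav[:16000], 100) gives the same array as np.tile(noisy_wav[:16000], 100).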

Since Mamba is implemented in Triton and relies on CUDA, the CUDA version, the GPU model, and the Mamba version can all affect performance. I suspect this issue is related to either the package version or the GPU. It may help to install Mamba directly by running "pip install ." inside the mamba_install directory.
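Because the reply points to the CUDA version, GPU model and Mamba version as likely factors, it helps to collect them in one place when comparing two machines. A stdlib-based sketch; the distribution names checked (torch, triton, mamba_ssm) are assumptions about what is installed, and the PyTorch-specific fields are skipped if PyTorch is missing.

```python
import platform
from importlib import metadata


def env_report(packages=("torch", "triton", "mamba_ssm")) -> dict:
    """Gather version information that can affect Mamba/Triton inference speed."""
    report = {"python": platform.python_version()}
    for pkg in packages:
        try:
            report[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            report[pkg] = "not installed"
    try:
        import torch  # optional: only used for CUDA/GPU details
    except ImportError:
        return report
    report["cuda (torch build)"] = torch.version.cuda  # None for CPU-only builds
    if torch.cuda.is_available():
        report["gpu"] = torch.cuda.get_device_name(0)
        report["cudnn"] = torch.backends.cudnn.version()
    return report


if __name__ == "__main__":
    for key, value in env_report().items():
        print(f"{key}: {value}")
```

Running this on both machines and comparing the output makes version or hardware differences easy to spot.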

Best regards, Roy Chao