google-deepmind / dm_aux

Apache License 2.0

When to use dm_aux's stft versus jax.scipy.signal.stft? #2


DBraun commented 8 months ago

I think it makes sense to use a fast Fourier transform (jax.scipy.signal.stft, which seems to use jax.numpy.fft.fft) when the window size is sufficiently large, which is usually the case in audio. However, dm_aux.spectral.stft uses a dense DFT matrix: https://github.com/google-deepmind/dm_aux/blob/77f5ed76df2928bac8550e1c5466c0dac2934be3/dm_aux/spectral.py#L546
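For readers comparing the two approaches, the following is a minimal sketch of what "STFT via a dense DFT matrix" means next to the FFT route. It is not dm_aux's actual code, which may pad, window and scale differently; the helper names (`frame_signal`, `dft_matrix`, `stft_dft_matmul`, `stft_fft`) are hypothetical. Both variants frame the signal and apply a Hann window. One then multiplies all frames by a precomputed one-sided DFT matrix in a single matmul, and the other calls `rfft` on each frame.

```python
import numpy as np


def frame_signal(x, frame_length, hop):
    """Split a 1-D signal into overlapping frames (no padding; trailing samples are dropped)."""
    n_frames = 1 + (len(x) - frame_length) // hop
    idx = np.arange(frame_length)[None, :] + hop * np.arange(n_frames)[:, None]
    return x[idx]  # shape (n_frames, frame_length)


def dft_matrix(n_fft):
    """One-sided DFT matrix W[n, k] = exp(-2j*pi*n*k/n_fft), shape (n_fft, n_fft // 2 + 1)."""
    n = np.arange(n_fft)[:, None]
    k = np.arange(n_fft // 2 + 1)[None, :]
    # Reduce n*k modulo n_fft so the complex exponent stays small and accurate.
    return np.exp(-2j * np.pi * ((n * k) % n_fft) / n_fft)


def stft_dft_matmul(x, n_fft, hop):
    """Dense route: window the frames, then one big matmul against the DFT matrix."""
    frames = frame_signal(x, n_fft, hop) * np.hanning(n_fft)
    return frames @ dft_matrix(n_fft)  # O(n_fft^2) work per frame, but a single GEMM


def stft_fft(x, n_fft, hop):
    """FFT route: window the frames, then a real-input FFT per frame."""
    frames = frame_signal(x, n_fft, hop) * np.hanning(n_fft)
    return np.fft.rfft(frames, axis=-1)  # O(n_fft log n_fft) work per frame


# Usage: both routes should give the same spectrogram up to rounding error.
rng = np.random.default_rng(0)
x = rng.standard_normal(16_000)  # e.g. 1 s of noise at 16 kHz
spec_matmul = stft_dft_matmul(x, n_fft=1024, hop=256)
spec_fft = stft_fft(x, n_fft=1024, hop=256)
print(spec_matmul.shape, np.max(np.abs(spec_matmul - spec_fft)))
```

Replacing `numpy` with `jax.numpy` should give a JAX version of the same idea. JAX computes in float32/complex64 by default, so a comparison there needs a looser tolerance. The matmul variant does more arithmetic per frame, but it is expressed as one large matrix multiplication, which is the kind of operation accelerators are built for.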

What's the reasoning for this? I ask because I was using jax.scipy.signal.stft but was getting warnings like this:

2024-02-20 19:27:45.321748: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=3,k3=0} for conv (f32[4,1,18432]{2,1,0}, u8[0]{0}) custom-call(f32[4,2048,20479]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
2024-02-20 19:27:56.297800: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 11.976181386s
Trying algorithm eng28{k2=3,k3=0} for conv (f32[4,1,18432]{2,1,0}, u8[0]{0}) custom-call(f32[4,2048,20479]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...

When I replaced jax.scipy.signal.stft with an equivalent call to dm_aux.spectral.stft, the warning no longer appeared. This is on WSL 2 with an NVIDIA RTX 2080.

Name: jax
Version: 0.4.24
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, jaxloudnorm, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.24+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
---
Name: flax
Version: 0.8.1
Summary: Flax: A neural network library for JAX designed for flexibility
Home-page:
Author:
Author-email: Flax team <flax-dev@google.com>
License:
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: jax, msgpack, numpy, optax, orbax-checkpoint, PyYAML, rich, tensorstore, typing-extensions
Required-by: clu

Update: I think I've realized why. An FFT is faster than a dense matrix multiplication when both run on the CPU. On a GPU, the matrix multiplication can be faster even though it does more arithmetic (O(N^2) per frame versus O(N log N) for the FFT), because dense matmuls map very well onto the GPU's hardware and parallelism. That is presumably why dm_aux uses a DFT matrix.
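The trade-off in the update can be made concrete with a rough operation count. This sketch rests on two assumptions. First, the dense version multiplies each frame by separate real cosine and sine matrices of shape (n_fft, n_fft // 2 + 1). Second, the FFT cost uses the conventional 2.5·N·log2(N) FLOP estimate for a real-input FFT. The function names are hypothetical, and the numbers are estimates, not measurements.

```python
import math


def dft_matmul_flops(n_fft: int) -> int:
    """Real FLOPs for one frame of a one-sided STFT computed as a dense matmul.

    Assumes separate real cos and sin matrices of shape (n_fft, n_fft // 2 + 1)
    and counts each multiply-add as 2 FLOPs.
    """
    n_bins = n_fft // 2 + 1
    macs = 2 * n_fft * n_bins  # cos matrix + sin matrix
    return 2 * macs


def rfft_flops_estimate(n_fft: int) -> float:
    """Conventional estimate for a real-input FFT of length n_fft: 2.5 * N * log2(N)."""
    return 2.5 * n_fft * math.log2(n_fft)


for n in (256, 512, 1024, 2048):
    ratio = dft_matmul_flops(n) / rfft_flops_estimate(n)
    print(f"n_fft={n:5d}  matmul={dft_matmul_flops(n):>12,d}  "
          f"fft~{rfft_flops_estimate(n):>10,.0f}  ratio~{ratio:6.1f}x")
```

For n_fft = 2048, the window size in the log above, this estimate puts the dense matmul at roughly 149 times the arithmetic of the FFT. A GPU with high matmul throughput can plausibly absorb that overhead while a CPU cannot. Actual timings depend on the hardware, batch size and XLA version.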