What's the reasoning for dm_aux.spectral.stft using a dense DFT matrix rather than an FFT? I ask because I was using jax.scipy.signal.stft but was getting warnings like this:
2024-02-20 19:27:45.321748: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=3,k3=0} for conv (f32[4,1,18432]{2,1,0}, u8[0]{0}) custom-call(f32[4,2048,20479]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
2024-02-20 19:27:56.297800: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 11.976181386s
Trying algorithm eng28{k2=3,k3=0} for conv (f32[4,1,18432]{2,1,0}, u8[0]{0}) custom-call(f32[4,2048,20479]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
When I replaced jax.scipy.signal.stft with an equivalent call to dm_aux.spectral.stft, the warning went away. This is using WSL 2 with an Nvidia RTX 2080.
Update: I think I understand now. An FFT beats a dense matrix multiplication when both have to run on the CPU. On a GPU, the matrix multiplication can be faster, because the hardware is designed for that kind of massively parallel work. So that's presumably why dm_aux uses a DFT matrix.
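To make the two approaches concrete, here is a minimal sketch showing that an FFT and a dense DFT matrix multiplication produce the same spectrum for windowed frames. It is not the dm_aux or jax.scipy.signal implementation: the helper names frame, stft_fft and stft_dft_matrix are made up for illustration, and the frame length, hop size and Hann window are arbitrary assumptions. There is no padding or boundary handling, so the output will not match either library exactly.

```python
import jax
import jax.numpy as jnp


def frame(signal, frame_length, hop):
    """Slice a 1-D signal into overlapping frames: (num_frames, frame_length)."""
    num_frames = 1 + (signal.shape[0] - frame_length) // hop
    starts = hop * jnp.arange(num_frames)
    idx = starts[:, None] + jnp.arange(frame_length)[None, :]
    return signal[idx]


def stft_fft(frames, window):
    """One-sided spectrum of each windowed frame via the FFT."""
    return jnp.fft.rfft(frames * window, axis=-1)


def stft_dft_matrix(frames, window):
    """Same spectrum via a dense matrix multiplication with a DFT matrix."""
    n = frames.shape[-1]
    t = jnp.arange(n)
    k = jnp.arange(n // 2 + 1)
    # Reduce k*t modulo n before scaling so the float32 phase stays accurate.
    phase = 2.0 * jnp.pi * ((t[:, None] * k[None, :]) % n) / n
    windowed = frames * window
    # exp(-i*phase) = cos(phase) - i*sin(phase), applied as two real matmuls.
    return windowed @ jnp.cos(phase) - 1j * (windowed @ jnp.sin(phase))


# Illustrative input only: random noise, 256-sample frames, 50% overlap.
signal = jax.random.normal(jax.random.PRNGKey(0), (4096,))
frames = frame(signal, frame_length=256, hop=128)
window = jnp.hanning(256)

spec_fft = stft_fft(frames, window)
spec_dft = stft_dft_matrix(frames, window)
print(spec_fft.shape, spec_dft.shape)
print("max abs difference:", jnp.max(jnp.abs(spec_fft - spec_dft)))
```

Both functions compute the same thing up to floating-point error. Which one runs faster depends on the backend and the frame length, so it is worth timing both on the actual hardware rather than assuming.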
I think that it makes sense to use a "fast" Fourier transform (jax.scipy.signal.stft, which seems to use jax.numpy.fft.fft) when the window size is sufficiently large, which is usually the case in audio. However, dm_aux.spectral.stft uses a dense DFT matrix: https://github.com/google-deepmind/dm_aux/blob/77f5ed76df2928bac8550e1c5466c0dac2934be3/dm_aux/spectral.py#L546
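As a rough illustration of the FFT versus dense-matrix trade-off, the sketch below compares textbook operation counts per frame: about N^2 for a dense DFT matrix and about N log2 N for a radix-2 FFT. The function names are hypothetical and constant factors are ignored, so these are order-of-magnitude estimates only. The window size 2048 is taken from the log above; 256 is an arbitrary smaller size for comparison.

```python
import math


def dense_dft_ops(n):
    """Approximate multiply-adds for one frame with a dense n x n DFT matrix."""
    return n * n


def fft_ops(n):
    """Approximate operations for one frame with a radix-2 FFT (constants ignored)."""
    return n * math.log2(n)


for n in (256, 2048):
    ratio = dense_dft_ops(n) / fft_ops(n)
    print(f"N={n}: dense={dense_dft_ops(n)}, fft={fft_ops(n):.0f}, ratio={ratio:.1f}")
```

These counts say nothing about how well each approach maps onto the hardware. A GPU may still run the dense version faster in wall-clock time, which only a benchmark on the target device can settle.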