google-deepmind / dm_aux

Apache License 2.0

istft has slow operation alarm / compilation warning #4

DBraun opened this issue 8 months ago

DBraun commented 8 months ago

I'm trying to JIT-compile a function that performs an STFT, a phase modification, and an iSTFT. Here's the warning:

2024-02-25 13:28:21.380282: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng28{k2=3,k3=0} for conv (f32[1,1,46080]{2,1,0}, u8[0]{0}) custom-call(f32[1,2048,48127]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
2024-02-25 13:28:27.795337: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 7.41517865s
Trying algorithm eng28{k2=3,k3=0} for conv (f32[1,1,46080]{2,1,0}, u8[0]{0}) custom-call(f32[1,2048,48127]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
2024-02-25 13:28:28.795536: E external/xla/xla/service/slow_operation_alarm.cc:65] Trying algorithm eng0{} for conv (f32[1,1,46080]{2,1,0}, u8[0]{0}) custom-call(f32[1,2048,48127]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...
2024-02-25 13:28:29.575761: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 1.78033004s
Trying algorithm eng0{} for conv (f32[1,1,46080]{2,1,0}, u8[0]{0}) custom-call(f32[1,2048,48127]{2,1,0}, f32[1,2048,2048]{2,1,0}), window={size=2048}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0}} is taking a while...

Here's the script:

import jax
import jax.numpy as jnp
from jax import jit

from dm_aux.spectral import stft, istft, Pad

@jit
def modify_phase_and_reconstruct(audio: jnp.ndarray, phase_shift: jnp.ndarray) -> jnp.ndarray:
    length = audio.shape[-1]

    # Perform STFT
    stft_output = stft(audio, n_fft=2048, frame_length=2048, frame_step=1024, window_fn='hann', pad=Pad.BOTH)

    # Modify phase
    stft_modified = stft_output * jnp.exp(1j * phase_shift)

    # Perform iSTFT
    reconstructed_audio = istft(stft_modified, frame_length=2048, frame_step=1024, window_fn='hann', pad=Pad.START,
                                length=length)

    return reconstructed_audio

channels = 1  # Mono audio
sample_length = 44100  # 1 second of audio at a sample rate of 44100 Hz

# Simulate a batch of audio signals
simulated_audio = jax.random.normal(jax.random.PRNGKey(0), (channels, sample_length))

# Test the JIT-compiled function with a phase shift of 0.5 radians
phase_shift = jnp.array([0.5])
for _ in range(5):
    modified_audio_output = modify_phase_and_reconstruct(simulated_audio, phase_shift)
    print('tick')

# Output shape
print(modified_audio_output.shape)
# assert sample_length == modified_audio_output.shape[-1]

The warning doesn't appear if modify_phase_and_reconstruct returns stft_modified early, so the problem seems to come from the istft function.
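One possible mitigation, assuming the delay really is cuDNN convolution autotuning as the alarm messages suggest, is to lower XLA's GPU autotune level via the XLA_FLAGS environment variable. This is a sketch of a workaround, not a fix for the underlying istft cost:

```python
import os

# Assumption: the slow step is cuDNN conv-algorithm autotuning at compile
# time. Lowering the autotune level trades a possibly slower conv kernel
# for much faster compilation. This must be set BEFORE importing jax,
# because jax reads XLA_FLAGS once at initialization.
os.environ['XLA_FLAGS'] = (
    os.environ.get('XLA_FLAGS', '') + ' --xla_gpu_autotune_level=1'
).strip()

# import jax  # must happen only after the flag is set
```
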

I'm using WSL on Windows, Python 3.10, and an Nvidia RTX 2080.

Name: jax
Version: 0.4.24
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, opt-einsum, scipy
Required-by: chex, clu, flax, jaxloudnorm, optax, orbax-checkpoint
---
Name: jaxlib
Version: 0.4.24+cuda12.cudnn89
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/admin/.local/lib/python3.10/site-packages
Requires: ml-dtypes, numpy, scipy
Required-by: chex, clu, optax, orbax-checkpoint
DBraun commented 8 months ago

So as you may expect, this isn't an issue for smaller frame_step and frame_length values. But I think it's related to https://github.com/google-deepmind/dm_aux/issues/2. For sufficiently large frame_length, there has to be a reason to use jax.scipy.signal.stft/istft instead.
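For comparison, here is a rough sketch of the same round trip using jax.scipy.signal.stft/istft. The mapping from dm_aux parameters is my assumption (nperseg = frame_length, noverlap = frame_length - frame_step); note the scipy-style API returns tuples rather than bare arrays:

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import stft, istft

@jax.jit
def modify_phase_and_reconstruct(audio: jnp.ndarray,
                                 phase_shift: jnp.ndarray) -> jnp.ndarray:
    # scipy-style API: stft returns (freqs, times, Zxx)
    _, _, Zxx = stft(audio, window='hann', nperseg=2048,
                     noverlap=1024, nfft=2048)

    # Modify phase
    Zxx = Zxx * jnp.exp(1j * phase_shift)

    # istft returns (times, reconstructed_signal)
    _, reconstructed = istft(Zxx, window='hann', nperseg=2048,
                             noverlap=1024, nfft=2048)
    return reconstructed

audio = jax.random.normal(jax.random.PRNGKey(0), (1, 44100))
out = modify_phase_and_reconstruct(audio, jnp.array([0.5]))
```

The reconstructed length may differ slightly from the input because of the boundary padding scipy-style stft applies, so a final slice to the original length may be needed.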