keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

Add rank>2 support for `stft` #825

Closed. james77777778 closed this 1 year ago.

james77777778 commented 1 year ago

After experimenting with `stft` and `istft`, I found the input rank limitation of `torch.stft` inconvenient: it only accepts a 1-D `(num_samples,)` or 2-D `(batch_size, num_samples)` input.

When a single waveform contains multiple channels, the input shape might be `(batch_size, num_channels, num_samples)`, which `torch.stft` does not support. The other backends handle this format without issue.

To address this, this PR adds reshape logic to the `stft` function in the torch backend, and the rank limitation of `ops.stft` has been removed.
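A minimal sketch of the reshape idea, assuming plain PyTorch: flatten every leading dimension into a single batch dimension, call `torch.stft` on the 2-D result, then restore the leading dimensions. The helper name `stft_any_rank` is hypothetical. This is not the code from this PR, and it keeps `torch.stft`'s own `(..., freq_bins, num_frames)` output layout.

```python
import torch


def stft_any_rank(x, n_fft, hop_length, window):
    """Hypothetical helper: apply torch.stft to an input of any rank >= 1.

    torch.stft only accepts (L,) or (B, L) inputs, so all leading
    dimensions are flattened into one batch dimension and restored afterwards.
    """
    batch_shape = x.shape[:-1]              # e.g. (batch_size, num_channels)
    flat = x.reshape(-1, x.shape[-1])       # (prod(batch_shape), num_samples)
    out = torch.stft(
        flat,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,
        return_complex=True,
    )                                       # (prod(batch_shape), n_fft // 2 + 1, num_frames)
    # Put the original leading dimensions back in front of (freq_bins, num_frames).
    return out.reshape(*batch_shape, *out.shape[-2:])


# Usage: a (batch_size, num_channels, num_samples) multi-channel waveform.
x = torch.randn(2, 3, 1000)
window = torch.hann_window(512)
spec = stft_any_rank(x, n_fft=512, hop_length=128, window=window)
print(spec.shape)  # torch.Size([2, 3, 257, 8])
```

Each channel is transformed independently, so `spec[b, c]` matches calling `torch.stft` on `x[b, c]` alone.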

Furthermore, I found that `torch.istft` can only be used under a specific condition, which is common in audio processing; other cases are handled by a custom implementation.
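The exact condition checked by the PR is not spelled out above. The sketch below assumes one common audio setup: a periodic Hann window whose length equals `n_fft`, a hop of `n_fft // 4`, and `center=True`. Under that setup it shows that `torch.istft` recovers a multi-channel waveform when the same flatten/restore trick is applied. The helper `istft_any_rank` is hypothetical, and the custom fallback implementation is not shown.

```python
import torch


def istft_any_rank(spec, n_fft, hop_length, window, length):
    """Hypothetical helper: invert a (..., n_fft // 2 + 1, num_frames) complex spectrogram.

    torch.istft, like torch.stft, takes at most one batch dimension, so the
    leading dimensions are flattened and restored afterwards.
    """
    batch_shape = spec.shape[:-2]
    flat = spec.reshape(-1, *spec.shape[-2:])   # (prod(batch_shape), freq_bins, num_frames)
    y = torch.istft(
        flat,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        center=True,
        length=length,                          # trim back to the original signal length
    )                                           # (prod(batch_shape), length)
    return y.reshape(*batch_shape, y.shape[-1])


# Round trip on a (batch_size, num_channels, num_samples) waveform.
x = torch.randn(2, 3, 1000)
window = torch.hann_window(512)                 # periodic Hann, length == n_fft
spec = torch.stft(
    x.reshape(-1, 1000), n_fft=512, hop_length=128, window=window,
    center=True, return_complex=True,
).reshape(2, 3, 257, -1)
y = istft_any_rank(spec, n_fft=512, hop_length=128, window=window, length=1000)
print(y.shape)  # torch.Size([2, 3, 1000])
print((x - y).abs().max())  # reconstruction error, expected to be near float32 precision
```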

The current implementation can be summarized in the following table:

| backend | stft | istft |
|---|---|---|
| numpy | `scipy.signal.stft` | `scipy.signal.istft` |
| jax | `jax.scipy.signal.stft` | `jax.scipy.signal.istft` |
| tf | `tf.signal.stft` | `tf.signal.inverse_stft` |
| torch | `torch.stft` | `torch.istft` & custom implementation |