pytorch / pytorch

RuntimeError in torch.istft with center=False: Window Overlap Add Issue #118507

Open happyTonakai opened 9 months ago

happyTonakai commented 9 months ago

🐛 Describe the bug

Issue Description:

Calling torch.istft with center=False and a Hann window raises a RuntimeError. The failure comes from the window overlap-add check and produces the following error message:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: istft(CPUComplexFloatType[513, 8], n_fft=1024, hop_length=256, win_length=1024, window=torch.FloatTensor{[1024]}, center=0, normalized=0, onesided=1, length=None, return_complex=0) window overlap add min: 1
[ CPUBoolType{} ]

Steps to Reproduce:

To reproduce the issue, you can use the following minimal working example:

import torch

nfft = 1024
hop = nfft // 4
T = 12
center = False
L = (T - 1) * hop if center else nfft + (T - 1) * hop
window = torch.hann_window(nfft)
x = torch.randn(L)

X = torch.stft(x, nfft, hop, nfft, window, center=center, onesided=True, return_complex=True)
xx = torch.istft(X, nfft, hop, nfft, window, center=center, onesided=True, return_complex=False)

print(torch.allclose(x, xx, atol=1e-6))

The same code works fine when center=True. It seems that someone else also had this problem.

Thank you for your assistance in resolving this issue.

Versions

PyTorch version: 2.1.2+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: (MinGW-W64 x86_64-msvcrt-posix-seh, built by Brecht Sanders) 13.0.0 20221030 (experimental)
Clang version: Could not collect
CMake version: version 3.24.2
Libc version: N/A

Python version: 3.10.12 | packaged by Anaconda, Inc. | (main, Jul  5 2023, 19:01:18) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19043-SP0
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3600
DeviceID=CPU0
Family=198
L2CacheSize=1024
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=3600
Name=Intel(R) Core(TM) i7-7700 CPU @ 3.60GHz
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] audio-data-pytorch==0.0.19
[pip3] numpy==1.25.2
[pip3] pytorch-lightning==2.0.7
[pip3] torch==2.1.2
[pip3] torch-pitch-shift==1.2.4
[pip3] torch-specinv==0.2.1
[pip3] torchaudio==2.1.2
[pip3] torchaudio-augmentations==0.2.4
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.1.0
[pip3] torchvision==0.16.2
[conda] audio-data-pytorch        0.0.19                   pypi_0    pypi
[conda] numpy                     1.25.2                   pypi_0    pypi
[conda] pytorch-lightning         2.0.7                    pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] torch-pitch-shift         1.2.4                    pypi_0    pypi
[conda] torch-specinv             0.2.1                    pypi_0    pypi
[conda] torchaudio                2.1.2                    pypi_0    pypi
[conda] torchaudio-augmentations  0.2.4                    pypi_0    pypi
[conda] torchinfo                 1.8.0                    pypi_0    pypi
[conda] torchmetrics              1.1.0                    pypi_0    pypi
[conda] torchvision               0.16.2                   pypi_0    pypi

cc @mruberry @peterbell10

happyTonakai commented 9 months ago

The cause seems to be that the first sample of the Hann window is zero. With torch.ones(nfft) or torch.hamming_window(nfft) as the window, the round trip works.
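To make the difference between the two windows concrete, here is a minimal sketch in plain Python. It assumes the standard periodic definitions (which, as far as I know, are the defaults for torch.hann_window and torch.hamming_window); the helper names periodic_hann and periodic_hamming are made up for illustration and are not part of PyTorch.

```python
import math

def periodic_hann(n_fft):
    # Periodic Hann: w[n] = 0.5 - 0.5 * cos(2*pi*n / N)
    # (assumed to correspond to torch.hann_window(n_fft) with its defaults)
    return [0.5 - 0.5 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]

def periodic_hamming(n_fft):
    # Periodic Hamming: w[n] = 0.54 - 0.46 * cos(2*pi*n / N)
    # (assumed to correspond to torch.hamming_window(n_fft) with its defaults)
    return [0.54 - 0.46 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]

hann = periodic_hann(1024)
hamming = periodic_hamming(1024)

print("hann[0]    =", hann[0])     # exactly zero at the first sample
print("hamming[0] =", hamming[0])  # about 0.08, so never zero
```

The Hamming window never reaches zero, which would explain why it does not trip the check at the first output sample.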

Markovvn1w commented 5 months ago

This error means that the window coverage is not enough to fully recover the original signal.

Specifically in this example, when using center=False, the first element of the output waveform is calculated as res[0] = something / window[0]. However, window[0] is 0 when hann_window is used, which results in an error when calculating res[0].

Source: https://github.com/pytorch/pytorch/blob/c404b2968cfe1163fff1802a6c1b71d5579a729b/aten/src/ATen/native/SpectralOps.cpp#L1173-L1181
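The coverage argument can be sketched in plain Python. This is a simplified illustration of my reading of the check, not the actual PyTorch implementation: it assumes the envelope is the overlap-add of the squared window, and that center=True trims n_fft // 2 samples from each end before the minimum is checked. The helper window_envelope is hypothetical.

```python
import math

def window_envelope(window, hop, n_frames):
    # Overlap-add of the squared window over all frames.
    # Hypothetical helper; simplified from the linked check.
    n_fft = len(window)
    env = [0.0] * (n_fft + hop * (n_frames - 1))
    for t in range(n_frames):
        for n, w in enumerate(window):
            env[t * hop + n] += w * w
    return env

# Same sizes as in the error message: n_fft=1024, hop_length=256, 8 frames.
n_fft, hop, n_frames = 1024, 256, 8
hann = [0.5 - 0.5 * math.cos(2 * math.pi * n / n_fft) for n in range(n_fft)]

env = window_envelope(hann, hop, n_frames)

# center=False: the whole envelope is used, including env[0] = hann[0] ** 2.
min_uncentered = min(env)

# center=True: n_fft // 2 samples are trimmed from both ends first.
min_centered = min(env[n_fft // 2 : len(env) - n_fft // 2])

print("center=False min:", min_uncentered)
print("center=True  min:", min_centered)
```

Under these assumptions the untrimmed envelope touches zero at the very first sample, while the trimmed one stays well above zero. That matches the observation that the round trip only fails with center=False.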