mozilla / TTS

:robot: :speech_balloon: Deep learning for Text to Speech (Discussion forum:
Mozilla Public License 2.0

torch.stft fails with "Expected all tensors to be on the same device" #619

Closed: gerazov closed this 3 years ago

gerazov commented 3 years ago

Vocoder training fails with the following error (full traceback at the end):

File ".../lib/python3.6/site-packages/torch/", line 516, in stft
    normalized, onesided, return_complex)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

I've located the problem in the TorchSTFT class in TTS/vocoder/ line 7:

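import torch

class TorchSTFT():
    def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
        """ Torch based STFT operation """
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # plain tensor attribute: calling .cuda() on the loss module that holds
        # this object does not move it, so it stays on the CPU
        self.window = getattr(torch, window)(win_length)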

The problem is self.window: TorchSTFT is a plain Python class, so its window tensor is not transferred to CUDA when the loss function is moved with criterion_gen.cuda() in the TTS/bin/ line 536. torch.stft then receives a CUDA input and a CPU window.
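A minimal sketch of why this happens: nn.Module.cuda() and .to() only convert the module's registered parameters, buffers and submodules, so a tensor held by a plain object is skipped. The classes PlainSTFT and Criterion below are hypothetical stand-ins for TorchSTFT and the vocoder loss, not the TTS code, and a dtype conversion with .to(torch.float64) is used in place of .cuda() so the demo also runs without a GPU (both go through the same Module._apply() path).

```python
import torch
from torch import nn


class PlainSTFT:
    """Hypothetical stand-in for the original TorchSTFT: a plain object holding a tensor."""

    def __init__(self, win_length):
        self.window = torch.hann_window(win_length)


class Criterion(nn.Module):
    """Hypothetical loss module that holds a PlainSTFT, as the vocoder losses hold TorchSTFT."""

    def __init__(self, win_length):
        super().__init__()
        self.stft = PlainSTFT(win_length)  # not an nn.Module, so it is never registered
        self.scale = nn.Parameter(torch.ones(1))  # a registered parameter, for comparison


crit = Criterion(1024)
# On a GPU machine this would be crit.cuda(); a dtype conversion exercises the
# same traversal of registered tensors and works on CPU-only machines.
crit.to(torch.float64)

print(crit.scale.dtype)              # torch.float64: registered parameters are converted
print(crit.stft.window.dtype)        # torch.float32: the plain attribute is left untouched
print(len(list(crit.parameters())))  # 1: the window is not part of the module's state
```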

I've managed to solve this by subclassing torch.nn.Module and registering self.window as a parameter. This way .cuda() also transfers the window to CUDA and stft works:

import torch
from torch import nn

class TorchSTFT(nn.Module):
    def __init__(self, n_fft, hop_length, win_length, window='hann_window'):
        """ Torch based STFT operation """
        super(TorchSTFT, self).__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # registered as a frozen parameter, so .cuda()/.to() move it with the module
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
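The PR registers the window with nn.Parameter(..., requires_grad=False). A common PyTorch alternative for non-trainable state is register_buffer: a buffer is also moved by .cuda()/.to() and saved in state_dict(), but it is not returned by .parameters(), so an optimizer built from those parameters never sees it. The sketch below swaps in a buffer and adds a forward that returns a magnitude spectrogram. The forward body, return_complex=True and the 1e-8 floor are assumptions for illustration, not the code in the PR (the original class defines __call__, whereas an nn.Module puts the computation in forward).

```python
import torch
from torch import nn


class TorchSTFT(nn.Module):
    """Torch based STFT operation (sketch: window stored as a buffer)."""

    def __init__(self, n_fft, hop_length, win_length, window="hann_window"):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # A buffer follows .cuda()/.to() like a parameter does, but it is not
        # returned by .parameters(), so no optimizer can update it.
        self.register_buffer("window", getattr(torch, window)(win_length))

    def forward(self, x):
        # x: (batch, samples) waveform; the window already sits on the module's device
        spec = torch.stft(
            x,
            self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
        )
        # Magnitude spectrogram, clamped away from zero (the floor value is an assumption)
        return torch.clamp(spec.abs(), min=1e-8)


stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024)
device = "cuda" if torch.cuda.is_available() else "cpu"
stft.to(device)
mag = stft(torch.randn(2, 4096, device=device))
print(mag.shape)  # torch.Size([2, 513, 17]): n_fft // 2 + 1 bins, 1 + 4096 // 256 frames
```

Because the window lives on the module, moving the loss with criterion_gen.cuda() is enough; no per-call .to(x.device) is needed.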

Here's the PR #620

In it, I've also passed return_complex=False to torch.stft, because of the UserWarning (see the log below) announcing a future change in its default behaviour.
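On return_complex: later PyTorch releases also deprecate return_complex=False, so a sketch meant to keep working across versions requests the complex output and converts it with torch.view_as_real, which yields the same trailing (real, imag) dimension that return_complex=False returned. The sizes (n_fft=1024, hop_length=256, win_length=1024) mirror the audio config in the log below; the random input is illustrative only.

```python
import torch

x = torch.randn(2, 4096)
window = torch.hann_window(1024)

# Complex output: shape (batch, n_fft // 2 + 1, frames), complex dtype
spec = torch.stft(x, 1024, hop_length=256, win_length=1024,
                  window=window, return_complex=True)

# view_as_real adds a trailing dim of size 2 (real, imag): the old real-valued layout
spec_ri = torch.view_as_real(spec)
print(spec_ri.shape)  # torch.Size([2, 513, 17, 2])

# Magnitude from the real/imag pair matches abs() of the complex tensor
mag_ri = torch.sqrt(spec_ri[..., 0] ** 2 + spec_ri[..., 1] ** 2)
print(torch.allclose(mag_ri, spec.abs(), atol=1e-4))  # True
```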

I'm wondering whether previous torch versions simply didn't report this and silently transferred the window between CPU and GPU. If so, this fix could also speed up vocoder training :thinking:

This is my environment:

# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main              
absl-py                   0.11.0                   pypi_0    pypi      
astroid                   2.4.2                    pypi_0    pypi        
astunparse                1.6.3                    pypi_0    pypi    
attrdict                  2.0.1                    pypi_0    pypi         
attrs                     20.3.0                   pypi_0    pypi         
audioread                 2.1.9                    pypi_0    pypi     
blas                      1.0                         mkl                       
bokeh                     1.4.0                    pypi_0    pypi       
ca-certificates           2020.12.8            h06a4308_0      
cachetools                4.2.0                    pypi_0    pypi     
cardboardlint             1.3.0                    pypi_0    pypi    
certifi                   2020.12.5        py36h06a4308_0       
cffi                      1.14.4                   pypi_0    pypi          
chardet                   4.0.0                    pypi_0    pypi       
click                     7.1.2                    pypi_0    pypi          
clldutils                 3.6.0                    pypi_0    pypi         
colorlog                  4.6.2                    pypi_0    pypi
csvw                      1.9.0                    pypi_0    pypi 
cycler                    0.10.0                   pypi_0    pypi 
cython                    0.29.21          py36h2531618_0
dataclasses               0.8                      pypi_0    pypi 
decorator                 4.4.2                    pypi_0    pypi  
filelock                  3.0.12                   pypi_0    pypi    
flask                     1.1.2                    pypi_0    pypi      
gast                      0.3.3                    pypi_0    pypi      
google-auth               1.24.0                   pypi_0    pypi
google-auth-oauthlib      0.4.2                    pypi_0    pypi 
google-pasta              0.2.0                    pypi_0    pypi     
grpcio                    1.34.0                   pypi_0    pypi         
h5py                      2.10.0                   pypi_0    pypi         
idna                      2.10                     pypi_0    pypi           
importlib-metadata        3.3.0                    pypi_0    pypi  
inflect                   5.0.2                    pypi_0    pypi           
intel-openmp              2020.2                      254               
isodate                   0.6.0                    pypi_0    pypi         
isort                     4.3.21                   pypi_0    pypi           
itsdangerous              1.1.0                    pypi_0    pypi     
jinja2                    2.11.2                   pypi_0    pypi          
joblib                    1.0.0                    pypi_0    pypi           
keras-preprocessing       1.1.2                    pypi_0    pypi
kiwisolver                1.3.1                    pypi_0    pypi        
lazy-object-proxy         1.4.3                    pypi_0    pypi   
ld_impl_linux-64          2.33.1               h53a641e_7       
libedit                   3.1.20191231         h14c3975_1         
libffi                    3.3                  he6710b0_2                   
libgcc-ng                 9.1.0                hdf63c60_0              
librosa                   0.7.2                    pypi_0    pypi          
libstdcxx-ng              9.1.0                hdf63c60_0             
llvmlite                  0.31.0                   pypi_0    pypi          
markdown                  3.3.3                    pypi_0    pypi     
markupsafe                1.1.1                    pypi_0    pypi     
matplotlib                3.3.3                    pypi_0    pypi        
mccabe                    0.6.1                    pypi_0    pypi       
mkl                       2020.2                      256                     
mkl-service               2.3.0            py36he8ac12f_0         
mkl_fft                   1.2.0            py36h23d657b_0           
mkl_random                1.1.1            py36h0573a6f_0      
ncurses                   6.2                  he6710b0_1  
nose                      1.3.7                    pypi_0    pypi 
numba                     0.48.0                   pypi_0    pypi
numpy                     1.18.5                   pypi_0    pypi
oauthlib                  3.1.0                    pypi_0    pypi
openssl                   1.1.1i               h27cfd23_0  
opt-einsum                3.3.0                    pypi_0    pypi
packaging                 20.8                     pypi_0    pypi
phonemizer                2.2.2                    pypi_0    pypi
pillow                    8.1.0                    pypi_0    pypi
pip                       20.3.3           py36h06a4308_0  
protobuf                  3.14.0                   pypi_0    pypi
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pycparser                 2.20                     pypi_0    pypi
pylint                    2.5.3                    pypi_0    pypi
pyparsing                 2.4.7                    pypi_0    pypi
pysbd                     0.3.3                    pypi_0    pypi
pysocks                   1.7.1                    pypi_0    pypi
python                    3.6.12               hcff3b4d_2  
python-dateutil           2.8.1                    pypi_0    pypi
pyworld                   0.2.12                   pypi_0    pypi
pyyaml                    5.3.1                    pypi_0    pypi
readline                  8.0                  h7b6447c_0  
regex                     2020.11.13               pypi_0    pypi
requests                  2.25.1                   pypi_0    pypi
requests-oauthlib         1.3.0                    pypi_0    pypi
resampy                   0.2.2                    pypi_0    pypi
rfc3986                   1.4.0                    pypi_0    pypi
rsa                       4.6                      pypi_0    pypi
scikit-learn              0.24.0                   pypi_0    pypi
scipy                     1.5.4                    pypi_0    pypi
segments                  2.2.0                    pypi_0    pypi
setuptools                51.0.0           py36h06a4308_2  
six                       1.15.0           py36h06a4308_0  
soundfile                 0.10.3.post1             pypi_0    pypi
sqlite                    3.33.0               h62c20be_0  
tabulate                  0.8.7                    pypi_0    pypi
tensorboard               2.4.0                    pypi_0    pypi
tensorboard-plugin-wit    1.7.0                    pypi_0    pypi
tensorboardx              2.1                      pypi_0    pypi
tensorflow                2.3.1                    pypi_0    pypi
tensorflow-estimator      2.3.0                    pypi_0    pypi
termcolor                 1.1.0                    pypi_0    pypi
threadpoolctl             2.1.0                    pypi_0    pypi
tk                        8.6.10               hbc83047_0  
toml                      0.10.2                   pypi_0    pypi
torch                     1.7.1                    pypi_0    pypi
tornado                   6.1                      pypi_0    pypi
tqdm                      4.55.1                   pypi_0    pypi
tts                       0.0.6+9cf474a            pypi_0    pypi
typed-ast                 1.4.2                    pypi_0    pypi
typing-extensions                  pypi_0    pypi
umap-learn                0.4.6                    pypi_0    pypi
unidecode                 0.04.20                  pypi_0    pypi
uritemplate               3.0.1                    pypi_0    pypi
urllib3                   1.26.2                   pypi_0    pypi
werkzeug                  1.0.1                    pypi_0    pypi
wheel                     0.36.2             pyhd3eb1b0_0  
wrapt                     1.12.1                   pypi_0    pypi
xz                        5.2.5                h7b6447c_0  
zipp                      3.4.0                    pypi_0    pypi
zlib                      1.2.11               h7b6447c_3 

And here's the whole Traceback:

$ python TTS/bin/ --config_path TTS/vocoder/configs/my_parallel_wavegan_config.json

 > Using CUDA:  True
 > Number of GPUs:  1
 > Git Hash: 7beaacc
 > Experiment folder: /home/vibe/tts/mozilla/Models/LJSpeech/pwgan-January-16-2021_11+48AM-7beaacc
 > Loading wavs from: /home/vibe/tts/databases/LJSpeech-1.1/wavs/
 > Setting up Audio Processor...
 | > sample_rate:22050
 | > resample:False
 | > num_mels:80
 | > min_level_db:-100
 | > frame_shift_ms:None
 | > frame_length_ms:None
 | > ref_level_db:0
 | > fft_size:1024
 | > power:None
 | > preemphasis:0.0
 | > griffin_lim_iters:None
 | > signal_norm:True
 | > symmetric_norm:True
 | > mel_fmin:50.0
 | > mel_fmax:7600.0
 | > spec_gain:1.0
 | > stft_pad_mode:reflect
 | > max_norm:4.0
 | > clip_norm:True
 | > do_trim_silence:True
 | > trim_db:60
 | > do_sound_norm:False
 | > stats_path:/home/vibe/tts/databases/LJSpeech-1.1/scale_stats.npy
 | > hop_length:256
 | > win_length:1024
 > Generator Model: parallel_wavegan_generator
 > Discriminator Model: parallel_wavegan_discriminator
 > Generator has 1320442 parameters
 > Discriminator has 99842 parameters

 > EPOCH: 0/10000

 > TRAINING (2021-01-16 11:48:49) 
/home/vibe/miniconda3/envs/tts/lib/python3.6/site-packages/torch/ UserWarning: stft will require the return_complex parameter be explicitly  specified in a future PyTorch release. Use return_complex=False  to preserve the current behavior or return_complex=True to return  a complex output. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:653.)
  normalized, onesided, return_complex)
 ! Run is removed from /home/vibe/tts/mozilla/Models/LJSpeech/pwgan-January-16-2021_11+48AM-7beaacc
Traceback (most recent call last):
  File "TTS/bin/", line 654, in <module>
  File "TTS/bin/", line 559, in main
  File "TTS/bin/", line 152, in train
    feats_real, y_hat_sub, y_G_sub)
  File "/home/vibe/miniconda3/envs/tts/lib/python3.6/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vibe/tts/mozilla/TTS_gerazov/TTS/vocoder/layers/", line 233, in forward
    stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat.squeeze(1), y.squeeze(1))
  File "/home/vibe/miniconda3/envs/tts/lib/python3.6/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vibe/tts/mozilla/TTS_gerazov/TTS/vocoder/layers/", line 70, in forward
    lm, lsc = f(y_hat, y)
  File "/home/vibe/miniconda3/envs/tts/lib/python3.6/site-packages/torch/nn/modules/", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/vibe/tts/mozilla/TTS_gerazov/TTS/vocoder/layers/", line 46, in forward
    y_hat_M = self.stft(y_hat)
  File "/home/vibe/tts/mozilla/TTS_gerazov/TTS/vocoder/layers/", line 25, in __call__
  File "/home/vibe/miniconda3/envs/tts/lib/python3.6/site-packages/torch/", line 516, in stft
    normalized, onesided, return_complex)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
omkarade commented 2 years ago

How do I solve this? RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!