yl4579 / StyleTTS

Official Implementation of StyleTTS
MIT License
396 stars, 64 forks

crashes during training #31

Closed: ppisljar closed this issue 1 year ago

ppisljar commented 1 year ago

After starting training, I get the following error, sometimes right away and sometimes after a few steps:

./aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [652,0,0], thread: [124,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [652,0,0], thread: [125,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [652,0,0], thread: [126,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [652,0,0], thread: [127,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
Traceback (most recent call last):
  File "/home/tts/StyleTTS/train_first.py", line 393, in <module>
    main()
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/home/tts/StyleTTS/train_first.py", line 149, in main
    ppgs, s2s_pred, s2s_attn_feat = model.text_aligner(mels, mask, texts)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tts/StyleTTS/Utils/ASR/models.py", line 45, in forward
    _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/tts/StyleTTS/Utils/ASR/models.py", line 130, in forward
    print(f"... {text_input} {decoder_inputs.size(1)}")
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 873, in __format__
    return object.__format__(self, format_spec)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 426, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py", line 636, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py", line 567, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py", line 327, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor_str.py", line 111, in __init__
    value_str = "{}".format(value)
  File "/opt/conda/lib/python3.10/site-packages/torch/_tensor.py", line 872, in __format__
    return self.item().__format__(fo
ppisljar commented 1 year ago

Running this on an A6000 with 46 GB of RAM.

yl4579 commented 1 year ago

I think it is an index out-of-range problem. Are you sure n_token is set to the number of tokens in your training data?
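A minimal sketch of the check suggested above, assuming the transcripts have already been converted to integer token IDs: an embedding table with n_token rows only accepts IDs from 0 to n_token - 1, so any larger ID trips the srcIndex < srcSelectDimSize assertion on the GPU. The function name find_out_of_range_ids and the toy data are hypothetical and not part of StyleTTS.

```python
from typing import Iterable, List, Sequence


def find_out_of_range_ids(batches: Iterable[Sequence[int]], n_token: int) -> List[int]:
    """Return the sorted token IDs that fall outside [0, n_token).

    A lookup into an embedding with n_token rows is only valid for these IDs;
    anything else is what the CUDA indexSelect assertion complains about.
    """
    bad = set()
    for ids in batches:
        for i in ids:
            if i < 0 or i >= n_token:
                bad.add(i)
    return sorted(bad)


if __name__ == "__main__":
    # Hypothetical toy data: ID 183 does not fit a table with 178 rows (valid IDs 0..177).
    toy_batches = [[0, 12, 45, 177], [3, 183, 7]]
    print(find_out_of_range_ids(toy_batches, n_token=178))  # [183]
```

Running this over the whole training set before launching train_first.py turns an asynchronous GPU assertion into a readable list of offending IDs.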

ppisljar commented 1 year ago

I updated meldataset.py and added the 7 new characters that my dataset contains to the list of characters. I also updated n_token to 185 in config.yml, but I am still getting the same error. I see an n_token setting in the config for the ASR model as well. Will I need to retrain the ASR model?
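A sketch of the bookkeeping involved in adding characters, assuming a flat symbol list like the one edited in meldataset.py, where each symbol's position is its token ID. The list, transcripts, and helper names below are hypothetical and much shorter than the real ones; the point is that n_token must be at least the largest ID plus one, and every character in the transcripts needs an entry.

```python
from typing import Dict, Iterable, List


def build_symbol_table(symbols: List[str]) -> Dict[str, int]:
    """Map each symbol to its position in the list, which serves as its token ID."""
    if len(set(symbols)) != len(symbols):
        # A repeated symbol would map to only one of its positions, leaving an unused ID.
        raise ValueError("symbol list contains duplicates")
    return {s: i for i, s in enumerate(symbols)}


def unknown_characters(texts: Iterable[str], table: Dict[str, int]) -> List[str]:
    """Characters that appear in the transcripts but have no token ID yet (sorted)."""
    return sorted({ch for text in texts for ch in text if ch not in table})


def min_n_token(table: Dict[str, int]) -> int:
    """The embedding needs one row per ID, i.e. largest ID + 1 rows."""
    return max(table.values()) + 1


if __name__ == "__main__":
    # Hypothetical base list: padding, space/period/comma, lowercase Latin letters (30 symbols).
    base = ["$"] + list(" .,") + list("abcdefghijklmnopqrstuvwxyz")
    transcripts = ["dober dan.", "čaša"]  # hypothetical transcripts

    missing = unknown_characters(transcripts, build_symbol_table(base))
    print(missing)  # ['č', 'š']

    # Append new characters at the end so the IDs of existing symbols stay the same.
    extended = base + missing
    print(min_n_token(build_symbol_table(extended)))  # 32
```

Appending rather than inserting keeps earlier IDs stable, but the new IDs are still larger than anything a model trained on the old list has seen.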

yl4579 commented 1 year ago

Yes, the n_token values must match. If you change that hyperparameter, you also have to retrain the ASR model. I will try to combine these steps in the StyleTTS 2 repo so you don't have to do them separately.
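A small consistency check along the lines of this answer, assuming both YAML files have already been parsed into dicts (for example with PyYAML) and that each stores the value under model_params.n_token. That key layout is an assumption, so compare it against your own config.yml and the ASR config. This fits the traceback, where the crash surfaces inside model.text_aligner (the ASR model), whose token embedding would presumably still have the old size. The values 185 and 178 (185 - 7) are taken from this thread; check_n_token is a hypothetical helper.

```python
from typing import Any, Mapping


def check_n_token(tts_cfg: Mapping[str, Any], asr_cfg: Mapping[str, Any], n_symbols: int) -> int:
    """Return the shared n_token, or raise ValueError if the configs are inconsistent.

    Assumes each config is a parsed YAML dict with a model_params.n_token entry
    (the key layout is an assumption; adjust it to your files).
    """
    tts_n = tts_cfg["model_params"]["n_token"]
    asr_n = asr_cfg["model_params"]["n_token"]
    if tts_n != asr_n:
        # The ASR embedding has asr_n rows; if asr_n is smaller, IDs >= asr_n overflow it.
        raise ValueError(
            f"n_token mismatch: TTS config has {tts_n}, ASR config has {asr_n}; "
            "retrain the ASR model with the new value"
        )
    if tts_n < n_symbols:
        raise ValueError(f"n_token={tts_n} is smaller than the symbol table ({n_symbols} symbols)")
    return tts_n


if __name__ == "__main__":
    tts = {"model_params": {"n_token": 185}}  # after adding 7 characters
    asr = {"model_params": {"n_token": 178}}  # ASR config assumed still at the old value
    try:
        check_n_token(tts, asr, n_symbols=185)
    except ValueError as err:
        print(err)
```

With the values from this thread the check reports the mismatch before training starts, instead of letting the run reach the GPU assertion.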