yl4579 / StyleTTS2

StyleTTS 2: Towards Human-Level Text-to-Speech through Style Diffusion and Adversarial Training with Large Speech Language Models
MIT License

Joint training is failing with Assertion error #262

Open nvadigauvce opened 3 months ago

nvadigauvce commented 3 months ago

Hi all, I am trying to fine-tune the LibriTTS base model on our custom dataset. Fine-tuning works through the initial and diffusion training stages, but once joint training starts I hit the error below. I have tried all the fixes mentioned in this thread, such as constraining audio length (1-30 sec), keeping text length under 200 tokens, and reducing slmadv_params (min_len: 100, lower max_len), but nothing helps much. The issue happens on both single-GPU and 4-GPU machines.
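For anyone checking the same constraints, a small script can flag over-long transcripts before training. This is only a sketch: the `wav_path|text|speaker_id` list format is an assumption about the train-list layout, and `overlong_entries` is a hypothetical helper, not StyleTTS2 code.

```python
def overlong_entries(lines, max_chars=200):
    """Return (line_number, text_length) for entries whose text exceeds the budget."""
    bad = []
    for n, line in enumerate(lines, start=1):
        parts = line.rstrip("\n").split("|")  # assumed format: wav_path|text|speaker_id
        if len(parts) >= 2 and len(parts[1]) > max_chars:
            bad.append((n, len(parts[1])))
    return bad

# Demo: the second entry's text is 250 characters long.
demo = ["a.wav|short text|0", "b.wav|" + "x" * 250 + "|0"]
print(overlong_entries(demo))  # → [(2, 250)]
```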

Any pointers on how to address the error below would be helpful.

Specifically, the error is raised while computing `bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())` during the `slm_out = slmadv(...)` call. Let me know if there is a known solution.

Error:

```
../aten/src/ATen/native/cuda/Indexing.cu:1289: indexSelectLargeIndex: block: [269,0,0], thread: [31,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

Traceback (most recent call last):
  File "/home/user/ASR/StyleTTS2/train_finetune.py", line 707, in <module>
    main()
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/home/user/ASR/StyleTTS2/train_finetune.py", line 487, in main
    slm_out = slmadv(i,
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/ASR/StyleTTS2/Modules/slmadv.py", line 26, in forward
    bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/parallel/data_parallel.py", line 183, in forward
    return self.module(*inputs[0], **module_kwargs[0])
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/ASR/StyleTTS2/Utils/PLBERT/util.py", line 9, in forward
    outputs = super().forward(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py", line 719, in forward
    encoder_outputs = self.encoder(
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py", line 468, in forward
    layer_group_output = self.albert_layer_groups[group_idx](
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py", line 383, in forward
    attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/disk2/arveti.manjunath/miniconda3/envs/stts2/lib/python3.12/site-packages/transformers/models/albert/modeling_albert.py", line 353, in forward
    context_layer = context_layer.transpose(2, 1).flatten(2)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
```
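For context (not stated in the thread itself): the `indexSelectLargeIndex` assertion fires when an index handed to an embedding lookup falls outside the table, i.e. a token id greater than or equal to the embedding's `num_embeddings`. A minimal sketch of scanning token ids for out-of-range values on CPU before they reach the GPU; the helper name and the demo vocab size are illustrative, not part of StyleTTS2:

```python
import torch

def find_out_of_range_ids(token_ids, vocab_size):
    """Return [batch, position] pairs whose id cannot be embedded."""
    bad = (token_ids >= vocab_size) | (token_ids < 0)
    return torch.nonzero(bad, as_tuple=False).tolist()

# Demo with an illustrative 178-symbol vocab; one id (300) is out of range.
ids = torch.tensor([[5, 17, 300, 2]])
print(find_out_of_range_ids(ids, vocab_size=178))  # → [[0, 2]]
```

Running this over each batch of `ref_text` before the `slmadv` call would pinpoint which samples carry ids the PL-BERT embedding cannot index.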

martinambrus commented 1 month ago

It seems you're running the accelerate version of the fine-tuning script in an environment where the error message doesn't help much. Try running the non-accelerated version via `python train_finetune.py --config_path ./Configs/config_ft.yml` and see what error you get; that should give a clearer message about what is really going wrong with your training.
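Along the same lines, since CUDA asserts are reported asynchronously, reproducing the failing lookup on CPU turns the opaque device-side assert into a plain `IndexError` that names the bad index. A toy illustration (the table sizes are arbitrary, chosen only for the demo):

```python
import torch

# A 10-symbol embedding table stands in for PL-BERT's token embedding.
emb = torch.nn.Embedding(num_embeddings=10, embedding_dim=4)
ids = torch.tensor([3, 12])  # 12 is out of range for a 10-symbol vocab

try:
    emb(ids)  # on CPU this raises immediately instead of a deferred CUDA assert
except IndexError as e:
    print("caught:", e)
```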

FormMe commented 1 month ago

I got the same issue while running the second-stage training, and I was using the non-accelerated version too.

I found that the problem is in the OOD texts. It seems that you (like me) are using OOD texts that don't fit the PL-BERT model: some of the phonemes are not present in PL-BERT's vocabulary. I worked around it by turning off the use of OOD texts (use_ind=True). I should switch to OOD texts that do fit PL-BERT later, though.

Anyway, it looks weird, because the OOD texts go through text_cleaner too.
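One alternative to disabling OOD texts entirely would be to pre-filter the OOD list to lines whose symbols are all known to the cleaner. This is only a sketch under assumptions: `symbol_to_id` stands in for the real TextCleaner mapping, and the demo alphabet is made up.

```python
def line_fits_vocab(line, symbol_to_id):
    """True only if every character maps to a known symbol id."""
    return all(ch in symbol_to_id for ch in line)

def filter_ood_lines(lines, symbol_to_id):
    # Keep only OOD lines that cannot produce out-of-vocabulary indices.
    return [line for line in lines if line_fits_vocab(line, symbol_to_id)]

# Demo with a made-up 9-symbol alphabet; the real mapping comes from the cleaner.
symbols = {ch: i for i, ch in enumerate("abcdefgh ")}
print(filter_ood_lines(["abba", "légume"], symbols))  # → ['abba']
```

The design choice here is to drop whole lines rather than strip unknown symbols, since silently deleting phonemes would change the utterance the SLM discriminator hears.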