shivammehta25 / Matcha-TTS

[ICASSP 2024] 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching
https://shivammehta25.github.io/Matcha-TTS/
MIT License

Error when trying to train on a multispeaker dataset #71

Closed · gardaa closed this 7 months ago

gardaa commented 7 months ago

Hi! I am trying to train the Matcha-TTS model on my own dataset in a low-resource language. Because of that, I have to use some ASR data to test it, even though I know it is not the ideal type of data for TTS training. Since it is a multispeaker dataset, I am changing the n_spks variable in dataset.yaml to 29, which is the total number of speakers in the training set (the validation set contains 45 different speakers).

From my research, I believe it might have something to do with changing n_spks from 1 to 29, which interferes with some indexing or boundary that has been set, but I am not sure. I did manage to run training without errors with n_spks=1 (even though the results were not great due to the noisy dataset).

When I run the script to train the model on a GPU, it gives me hundreds of lines with this error:

../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1290: indexSelectLargeIndex: block: [0,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

Followed by:

Traceback (most recent call last):
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 412, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/models/baselightningmodule.py", line 128, in validation_step
    loss_dict = self.get_losses(batch)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/models/baselightningmodule.py", line 61, in get_losses
    dur_loss, prior_loss, diff_loss = self(
                                      ^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/models/matcha_tts.py", line 176, in forward
    mu_x, logw, x_mask = self.encoder(x, x_lengths, spks)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/models/components/text_encoder.py", line 397, in forward
    x = self.emb(x) * math.sqrt(self.n_channels)
        ~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~~~
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/train.py", line 112, in main
    metric_dict, _ = train(cfg)
                     ^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/utils/utils.py", line 86, in wrap
    raise ex
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/utils/utils.py", line 76, in wrap
    metric_dict, object_dict = task_func(cfg=cfg)
                               ^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/MatchaTTS_Norwegian_Custom/matcha/train.py", line 79, in train
    trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 68, in _call_and_handle_interrupt
    trainer._teardown()
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1010, in _teardown
    self.strategy.teardown()
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 537, in teardown
    self.lightning_module.cpu()
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/lightning/fabric/utilities/device_dtype_mixin.py", line 82, in cpu
    return super().cpu()
           ^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 960, in cpu
    return self._apply(lambda t: t.cpu())
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 802, in _apply
    module._apply(fn)
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 825, in _apply
    param_applied = fn(param)
                    ^^^^^^^^^
  File "/global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 960, in <lambda>
    return self._apply(lambda t: t.cpu())
                                 ^^^^^^^
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x155552cf4d87 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x155552ca575f in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x155552dc58a8 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0xfa5656 (0x1555086d3656 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libtorch_cuda.so)
frame #4: <unknown function> + 0x543010 (0x1555516bf010 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x649bf (0x155552cd99bf in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #6: c10::TensorImpl::~TensorImpl() + 0x21b (0x155552cd2c8b in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #7: c10::TensorImpl::~TensorImpl() + 0x9 (0x155552cd2e39 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #8: <unknown function> + 0x80b718 (0x155551987718 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #9: THPVariable_subclass_dealloc(_object*) + 0x2f6 (0x155551987a96 in /global/D1/homes/gardaa/gards-py311-cu121-venv/lib/python3.11/site-packages/torch/lib/libtorch_python.so)
frame #10: python() [0x52d8be]
frame #11: python() [0x56e86e]
frame #12: python() [0x56d568]
frame #13: python() [0x56d5a9]
frame #14: python() [0x56d5a9]
frame #15: python() [0x56d5a9]
frame #16: python() [0x56d5a9]
frame #17: python() [0x56d5a9]
frame #18: python() [0x56d5a9]
frame #19: python() [0x57b43a]
frame #20: python() [0x57b4f9]
frame #21: python() [0x5b7c77]
frame #22: python() [0x523195]
frame #23: python() [0x60ee6e]

I am looking forward to your help with this issue. Thank you so much in advance!

shivammehta25 commented 7 months ago

This is a common PyTorch issue that occurs when the input to an embedding layer contains indices larger than the size of the embedding layer. Could you please ensure that the number of symbols equals the num_embeddings value passed to nn.Embedding()?
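For illustration, here is a minimal standalone sketch (not Matcha-TTS code) of how an out-of-range index produces exactly this failure; on CPU it raises a readable IndexError, while on CUDA it surfaces as the asynchronous device-side assert shown in the log above:

import torch

# An embedding table with 29 rows only accepts indices 0..28.
emb = torch.nn.Embedding(num_embeddings=29, embedding_dim=64)

ok = torch.tensor([0, 5, 28])
print(emb(ok).shape)      # torch.Size([3, 64])

bad = torch.tensor([35])  # 35 >= 29, so the lookup is out of range
try:
    emb(bad)
except IndexError as e:
    print(e)              # "index out of range in self" (device-side assert on CUDA)

The same applies to the symbol embedding in the text encoder: the vocabulary size passed to nn.Embedding must be strictly greater than the largest symbol ID the tokenizer can produce.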

gardaa commented 7 months ago

I have tried changing line 51 of matcha_tts.py (see the code below) to the length of the symbols (which is 216), but I cannot get it to work. What format should it have, and is this the correct file?

if n_spks > 1:
    self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)

Edit:

The error is now fixed. Since the speaker IDs in the dataset were not consecutive values 1..n but arbitrary numbers (due to the nature of the dataset), some IDs exceeded the embedding size and triggered the assert. I had to set the number of speakers to the maximum speaker ID + 1.
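To make the fix concrete, here is a hedged sketch for anyone hitting the same problem (the helper name, the example path, and the "wav_path|spk_id|text" filelist layout are assumptions, not part of the repo): it reports the two options, namely setting n_spks to max speaker ID + 1, or remapping the IDs to a contiguous 0..n-1 range so n_spks can stay equal to the true speaker count.

# Hypothetical helper: inspect speaker IDs in a filelist with "wav_path|spk_id|text" lines.
def check_speaker_ids(filelist_path: str) -> dict:
    spk_ids = set()
    with open(filelist_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            spk_ids.add(int(line.strip().split("|")[1]))  # speaker ID assumed to be the 2nd field

    print(f"{len(spk_ids)} distinct speakers, max ID = {max(spk_ids)}")
    print(f"Option 1: set n_spks to at least {max(spk_ids) + 1} (embedding indices start at 0)")

    # Option 2: remap IDs to 0..n-1 and rewrite the filelists, so n_spks == number of speakers.
    return {old: new for new, old in enumerate(sorted(spk_ids))}

remap = check_speaker_ids("data/filelists/train.txt")  # example path, adjust to your setup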