lucidrains / e2-tts-pytorch

Implementation of E2-TTS, "Embarrassingly Easy Fully Non-Autoregressive Zero-Shot TTS", in PyTorch
MIT License
228 stars, 21 forks

multi-GPU training #19

Closed: eschmidbauer closed this issue 1 month ago

eschmidbauer commented 1 month ago

Is it possible to train across multiple GPUs?

lucidrains commented 1 month ago

@eschmidbauer it uses Hugging Face Accelerate under the hood, so just follow the Accelerate multi-GPU instructions.

lucidrains commented 1 month ago

@eschmidbauer are you seeing something on 1 GPU?

eschmidbauer commented 1 month ago

I get an error when running: accelerate launch ./train_e2.py

[rank0]: Traceback (most recent call last):
[rank0]:   File "/e2-tts-api/app/./train_e2.py", line 41, in <module>
[rank0]:     trainer.train(train_dataset, epochs, batch_size, num_workers, save_step=10)
[rank0]:   File "/e2-tts-api/app/e2_tts_pytorch/trainer.py", line 220, in train
[rank0]:     loss = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/e2-tts-api/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/e2-tts-api/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/e2-tts-api/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1589, in forward
[rank0]:     inputs, kwargs = self._pre_forward(*inputs, **kwargs)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/e2-tts-api/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1480, in _pre_forward
[rank0]:     if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
[rank0]: making sure all `forward` function outputs participate in calculating loss.
[rank0]: If you already have done the above, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).
[rank0]: Parameter indices which did not receive grad for rank 0: 1 14 27 40 53 66 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264
[rank0]:  In addition, you can set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL to print out information about which particular parameters did not receive gradient on this rank as part of this error
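The index list in this error can be summarised by grouping it into arithmetic runs. The small helper below is hypothetical (plain text parsing, no torch). On this log it shows that the first six indices, 1 through 66, are spaced 13 apart, which looks like one parameter per repeated block such as a transformer layer, while 171 through 264 form one contiguous run, which looks more like a whole submodule. For the actual parameter names, the error itself recommends setting TORCH_DISTRIBUTED_DEBUG to DETAIL. The find_unused_parameters=True workaround it mentions can be passed through Accelerate as Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]), but that only tolerates the unused parameters at some extra cost per iteration rather than removing them.

```python
import re


def group_unused_indices(log_line: str) -> list[tuple[int, int, int]]:
    """Parse the indices from a DDP "did not receive grad" message and split them
    into arithmetic runs, returned as (first, last, step) with last inclusive.

    Hypothetical debugging helper: it only parses text and never imports torch.
    """
    match = re.search(r"did not receive grad for rank \d+:\s*([\d ]+)", log_line)
    if match is None:
        return []
    indices = [int(tok) for tok in match.group(1).split()]

    runs: list[tuple[int, int, int]] = []
    i = 0
    while i < len(indices):
        if i + 1 == len(indices):  # a lone trailing index is a run of length 1
            runs.append((indices[i], indices[i], 0))
            break
        step = indices[i + 1] - indices[i]
        j = i + 1
        # Extend the run while the spacing between neighbours stays the same.
        while j + 1 < len(indices) and indices[j + 1] - indices[j] == step:
            j += 1
        runs.append((indices[i], indices[j], step))
        i = j + 1
    return runs


line = (
    "[rank0]: Parameter indices which did not receive grad for rank 0: "
    "1 14 27 40 53 66 " + " ".join(str(n) for n in range(171, 265))
)
print(group_unused_indices(line))  # → [(1, 66, 13), (171, 264, 1)]
```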
skirdey commented 1 month ago

My best guess would be that these parameters come from the duration predictor, if you have it defined.
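This hypothesis can be checked in a single process before launching DDP: after one backward pass, any trainable parameter whose .grad is still None did not contribute to the loss, and parameters like that are what the DDP reducer complains about. In this sketch, ToyModel and its duration_head are hypothetical stand-ins for a model with a defined-but-unused submodule; with a real model you would pass in the loss from its own forward pass instead.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical stand-in: `head` is used in forward, `duration_head` is not,
    mimicking a submodule that is defined but never contributes to the loss."""

    def __init__(self) -> None:
        super().__init__()
        self.head = nn.Linear(8, 1)
        self.duration_head = nn.Linear(8, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x).mean()


def params_without_grad(model: nn.Module, loss: torch.Tensor) -> list[str]:
    """Run backward once and list trainable parameters whose .grad is still None."""
    model.zero_grad(set_to_none=True)  # clear any stale gradients first
    loss.backward()
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


model = ToyModel()
loss = model(torch.randn(4, 8))
print(params_without_grad(model, loss))  # → ['duration_head.weight', 'duration_head.bias']
```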

Coice commented 1 month ago

@eschmidbauer in the transformer, change:

skip_proj = Linear(dim * 2, dim, bias = False) if needs_skip_proj else None

to (the loop variable also has to be named i rather than _, i.e. for i in range(depth)):

skip_proj = Linear(dim * 2, dim, bias = False) if needs_skip_proj and i >= depth // 2 else None
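Why restricting skip_proj to i >= depth // 2 helps, as a pure-Python sketch. It assumes U-Net-style skips with needs_skip_proj true: the first depth // 2 layers push their input onto a stack, and each later layer pops one, concatenates it with its own input (hence dim * 2) and projects back to dim with its skip_proj. Under that assumption only second-half layers ever call skip_proj, so projections built for first-half layers never receive gradients. The depth of 12 is an illustrative choice; six unused projections would be consistent with the six evenly spaced indices (1 to 66) in the error above.

```python
def skip_proj_layers(depth: int, fixed: bool) -> set[int]:
    """Layer indices that get a skip projection when the module is built.

    fixed=False mirrors `if needs_skip_proj` (every layer); fixed=True mirrors
    the suggested `if needs_skip_proj and i >= depth // 2`.
    """
    return {i for i in range(depth) if not fixed or i >= depth // 2}


def skip_proj_layers_used(depth: int) -> set[int]:
    """Simulate an assumed U-Net-style forward pass and record which layers
    actually call their skip projection."""
    assert depth % 2 == 0, "this sketch assumes an even depth"
    skips: list[int] = []
    used: set[int] = set()
    for i in range(depth):
        if i < depth // 2:
            skips.append(i)  # first half: remember this layer's input
        else:
            skips.pop()      # second half: consume the matching skip...
            used.add(i)      # ...and run it through this layer's skip_proj
    return used


depth = 12
unused_before = skip_proj_layers(depth, fixed=False) - skip_proj_layers_used(depth)
unused_after = skip_proj_layers(depth, fixed=True) - skip_proj_layers_used(depth)
print(sorted(unused_before))  # → [0, 1, 2, 3, 4, 5]
print(sorted(unused_after))   # → []
```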

eschmidbauer commented 1 month ago
diff --git a/e2_tts_pytorch/e2_tts.py b/e2_tts_pytorch/e2_tts.py
index 1490ffa..966bdf0 100644
--- a/e2_tts_pytorch/e2_tts.py
+++ b/e2_tts_pytorch/e2_tts.py
@@ -275,14 +275,14 @@ class Transformer(Module):
                 nn.SiLU()
             )

-        for _ in range(depth):
+        for i in range(depth):
             attn_norm = rmsnorm_klass(dim)
             attn = Attention(dim = dim, heads = heads, dim_head = dim_head, dropout = dropout, **attn_kwargs)

             ff_norm = rmsnorm_klass(dim)
             ff = FeedForward(dim = dim, glu = True, dropout = dropout, **ff_kwargs)

-            skip_proj = Linear(dim * 2, dim, bias = False) if needs_skip_proj else None
+            skip_proj = Linear(dim * 2, dim, bias = False) if needs_skip_proj and i>=depth//2 else None

             self.layers.append(ModuleList([
                 skip_proj,

same error with the above change

eschmidbauer commented 1 month ago

Actually, it's a slightly different error, probably because I just did a git pull:

[rank0]:   File "/e2-tts-api/app/e2_tts_pytorch/trainer.py", line 235, in train
[rank0]:     loss, cond, pred = self.model(mel_spec, text=text_inputs, lens=mel_lengths)
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

lucidrains commented 1 month ago

@Coice ah, thanks for catching this. I've put in the fix.

@eschmidbauer could you try 0.2.7?
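To confirm the upgrade took effect (for example after pip install --upgrade e2-tts-pytorch), a small sketch using importlib.metadata from the standard library. It assumes the package is installed under the distribution name e2-tts-pytorch. The version parsing is deliberately simplified: numeric parts are compared as integer tuples, so 0.2.10 correctly sorts after 0.2.7 (a plain string comparison gets this wrong), and non-numeric parts such as pre-release tags are ignored; packaging.version handles those properly.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def version_tuple(v: str) -> tuple[int, ...]:
    """Turn a plain release string such as '0.2.7' into (0, 2, 7).

    Simplified on purpose: any dot-separated part that is not purely numeric
    (e.g. a pre-release tag) is dropped.
    """
    return tuple(int(part) for part in v.split(".") if part.isdigit())


def has_fix(dist_name: str = "e2-tts-pytorch", minimum: str = "0.2.7") -> Optional[bool]:
    """True/False if the distribution is installed, None if it is not."""
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return None
    return version_tuple(installed) >= version_tuple(minimum)


# True if 0.2.7 or newer is installed, False if older, None if not installed.
print(has_fix())
```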

eschmidbauer commented 1 month ago

great, it's working now!

lucidrains commented 1 month ago

nice, that's what I like to hear

Coice commented 1 month ago

Just for some info: I tried training on a 12x RTX 4090 rig with 55k hours of audio, let it run for days, and could never get the text embeddings to work. It wasn't this version of the code (it was based on Voicebox), so hopefully your training bears fruit 🍒😬

lucidrains commented 1 month ago

@Coice yeah I've heard this from multiple independent researchers

On the other hand, I think Microsoft and some other folks have gotten it working and even written papers with it, so maybe they found a bug but never bothered sending the fix upstream. But that's fine: as long as they publish, we all benefit from the knowledge.