lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

accelerate's wait_for_everyone hangs on the final step of training coarse/fine transformer #210

Closed. LWprogramming closed this issue 1 year ago

LWprogramming commented 1 year ago

Right before and after the wait_for_everyone call, I print the device and step, and I find that the main GPU sometimes never makes it past the call while the others do. I'm not sure why, but after cleaning up the logging output it ends up looking like this (arriving at 2 means right before the call, arriving at 3 means right after it; a sketch of the instrumentation follows the log):

499: device cuda:4 arrived at 2
499: device cuda:5 arrived at 2
499: device cuda:2 arrived at 2
499: device cuda:3 arrived at 2
499: device cuda:6 arrived at 2
499: device cuda:7 arrived at 2
499: device cuda:1 arrived at 2
498: device cuda:0 arrived 3
coarse 499: loss: 8.292579423141433e-07
499: device cuda:0 arrived at 2
499: device cuda:1 arrived 3
499: device cuda:3 arrived 3
499: device cuda:4 arrived 3
499: device cuda:6 arrived 3
499: device cuda:7 arrived 3
499: device cuda:5 arrived 3
499: device cuda:2 arrived 3
500: device cuda:4 arrived at 2
500: device cuda:3 arrived at 2
500: device cuda:6 arrived at 2
500: device cuda:2 arrived at 2
500: device cuda:1 arrived at 2
500: device cuda:7 arrived at 2
499: device cuda:0 arrived 3
500: device cuda:5 arrived at 2
coarse 500: loss: 8.264195798801666e-07
**500: device cuda:0 arrived at 2**
500: device cuda:3 arrived 3
500: device cuda:2 arrived 3
500: device cuda:7 arrived 3
500: device cuda:1 arrived 3
500: device cuda:4 arrived 3
500: device cuda:6 arrived 3
500: device cuda:5 arrived 3
**#cuda:0 never arrives at 3**
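
For reference, the instrumentation is roughly the following pattern (a minimal sketch, not the actual trainer code; `accelerator` and `step` are stand-ins for whatever the training loop really uses):

```python
from accelerate import Accelerator

accelerator = Accelerator()

def barrier_with_logging(step: int):
    # checkpoint 2: reached just before the barrier
    print(f"{step}: device {accelerator.device} arrived at 2", flush=True)
    # every process must reach this call before any of them can continue
    accelerator.wait_for_everyone()
    # checkpoint 3: reached only if the barrier actually releases this process
    print(f"{step}: device {accelerator.device} arrived at 3", flush=True)
```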

There are a couple of weird things about this:

I found a few existing issues about wait_for_everyone (here and here), but I don't think I've fully understood them. Another thing I wonder about: if the other devices proceed past the wait and finish the full training run, and only then does the main device go on to report the validation loss and save a checkpoint, could that result in weird behavior?
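
For illustration, here's a hedged sketch of the pattern that question is about (not the actual trainer code; the step count and the work inside the `is_main_process` branch are placeholders):

```python
from accelerate import Accelerator

accelerator = Accelerator()

for step in range(500):  # placeholder step count
    # ... one training step, executed on every rank ...

    accelerator.wait_for_everyone()  # the "arrived at 2" -> "arrived at 3" barrier

    if accelerator.is_main_process:
        # the main rank does extra work here (validation loss, checkpoint saving)
        # while the other ranks race ahead to the next step's barrier, or exit the
        # loop entirely on the last step; if anything in this branch is itself a
        # collective op, the ranks can end up waiting on mismatched barriers
        pass
```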

LWprogramming commented 1 year ago

Oh dear, this is embarrassing... as soon as I got up and headed out, I had an idea I want to try that might fix this. Will let you know how it goes.

LWprogramming commented 1 year ago

Never mind, it didn't work. I wondered whether, just from initializing the trainers, I was inadvertently doing weird things with the accelerators, but it still hangs in the same way as before.

lucidrains commented 1 year ago

ah, can't help you until i get back into this next month

do you want to try raising the issue at huggingface? they have more experts on the training side

LWprogramming commented 1 year ago

Ok, I've found a workaround, but I have no idea why it works. Basically, if I insert an extra wait_for_everyone call after the training loop, training does finish properly, but the main process still sticks around forever waiting for something, and I'm not sure what. Luckily it should be possible to train things anyway (as I did in #214), just by running three separate jobs to train each transformer separately and then a fourth to sample using the checkpoints.
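
Roughly, the workaround looks like this (a sketch with placeholder structure, not the real trainer code):

```python
from accelerate import Accelerator

accelerator = Accelerator()

for step in range(500):  # placeholder step count
    # ... one training step per rank ...
    accelerator.wait_for_everyone()

# the extra barrier after the training loop: with this in place the training run
# itself finishes on every rank, although the main process still hangs afterwards
# waiting for something unknown
accelerator.wait_for_everyone()
```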

TL;DR: Things are definitely broken, but not badly enough to stop me from doing training work, so I'll close this issue but leave the huggingface one up. Hopefully they come up with suggestions for what could be going on.