lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

Accelerate integration #95

Closed nateraw closed 2 years ago

nateraw commented 2 years ago

Continuing the conversation from #73 to track progress on integrating accelerate with the trainer here.

CC @sgugger @muellerzr who may be able to help if you run into issues @lucidrains.

jacobwjs commented 2 years ago

Just to add a positive use case for Accelerate, I've been using it (DDP backend) since some of your earlier versions (~v0.06.0) to run distributed training. You shouldn't run into any major problems :)

lucidrains commented 2 years ago

@sgugger @muellerzr thank you both for the library

Sylvain, been following you since you were working with Jeremy at fast.ai. Thank you for all your contributions so far :pray:

lucidrains commented 2 years ago

@jacobwjs yes indeed, I'm integrating it over at https://github.com/lucidrains/denoising-diffusion-pytorch , as a practice run, and so far so good :smile:

lucidrains commented 2 years ago

working great! the only thing was that i'm saving all the state manually and wasn't sure where the GradScaler lived, but i found it on the accelerator instance
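
for reference, a rough sketch of what i mean by saving state manually (variable names and the checkpoint path are just placeholders; assumes mixed precision is on, otherwise accelerator.scaler is None):

import torch

# model, optimizer, accelerator assumed to already exist and to have gone through accelerator.prepare(...)
checkpoint = dict(
    model = accelerator.unwrap_model(model).state_dict(),
    optimizer = optimizer.state_dict(),
    scaler = accelerator.scaler.state_dict() if accelerator.scaler is not None else None
)

if accelerator.is_main_process:
    torch.save(checkpoint, './checkpoint.pt')  # illustrative path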

lucidrains commented 2 years ago

i also see that gradient accumulation is in the works, so i can pare down some more logic in the future (right now i'm doing it manually)
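
something like this is what i mean by doing it manually (sketch only; grad_accum_steps and the dataloader / model / optimizer are whatever the trainer already has):

grad_accum_steps = 4  # illustrative value

for step, batch in enumerate(dataloader):
    loss = model(batch)
    # scale the loss so the accumulated gradients match a full-sized batch
    accelerator.backward(loss / grad_accum_steps)

    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()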

lucidrains commented 2 years ago

pretty certain Imagen will be multi-GPU trainable by end of the week now

muellerzr commented 2 years ago

@lucidrains is there a particular reason you can't use Accelerator.save_state/Accelerator.load_state for what you're trying to accomplish? It should save/load everything stored including the scaler states.

I've also debated on writing individual funcs for saving each of those items as well.
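
A minimal sketch of the round trip (the directory name is just an example):

# saves everything passed through accelerator.prepare(...) plus the scaler and RNG states
accelerator.save_state('checkpoints/step_1000')

# later, to resume
accelerator.load_state('checkpoints/step_1000')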

lucidrains commented 2 years ago

@muellerzr yup i noticed that! however, i was manually keeping track of the training step count, and had an exponential moving average version of the model on the main process to boot

i can always wrap it in an object with the necessary methods and just register it with accelerate for checkpointing, but was feeling lazy
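
roughly what that wrapper would look like (a sketch; StepCounter and ema_model are made-up names, and register_for_checkpointing just needs the object to expose state_dict / load_state_dict):

class StepCounter:
    # tiny wrapper so the manual step count rides along with accelerate's checkpointing
    def __init__(self):
        self.step = 0

    def state_dict(self):
        return dict(step = self.step)

    def load_state_dict(self, state):
        self.step = state['step']

step_counter = StepCounter()
accelerator.register_for_checkpointing(step_counter)
accelerator.register_for_checkpointing(ema_model)  # assuming the EMA wrapper also exposes state_dict / load_state_dict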

lucidrains commented 2 years ago

@muellerzr will definitely test out those features some time this week!

jacobwjs commented 2 years ago

Another thing to think about is how this scales out once large models are getting trained. Inference (and hosting) becomes quite an issue if you're training on big compute.

If we're integrating with Accelerate, I think the sooner we get away from hand-rolled methods, the sooner we can take advantage of some really cool features.

https://huggingface.co/docs/accelerate/big_modeling

I use the above to compute my embeddings on some rather large language models. Without it, the only other options are CPU (sloooooow) or getting dirty with hand-placed layer layouts on GPU.
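
Roughly what that looks like (a sketch; MyBigEncoder and the checkpoint path are placeholders for whatever large text encoder you're embedding with):

from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# instantiate the model without allocating memory for its weights
with init_empty_weights():
    model = MyBigEncoder()

# then let accelerate load the real weights and spread the layers across the available GPUs / CPU
model = load_checkpoint_and_dispatch(
    model,
    '/path/to/checkpoint',
    device_map = 'auto'
)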

lucidrains commented 2 years ago

@jacobwjs yup, i'm aware, though it is probably still a work in progress atm

ok, i squared away yet another blocker (dataloader management), so time to put my mind on the accelerate integration and get it done with! should be done by end of this weekend (software estimates, always multiply by 2 or 3 :laughing: )

lucidrains commented 2 years ago

one thing i wasn't sure of is, if i were to save the optimizer and scheduler manually, would i need to unwrap those as well, similar to how the model is unwrapped using accelerator.unwrap_model?

muellerzr commented 2 years ago

You can just use Accelerate's save function (not the accelerator's) and pass in the optimizer and scheduler.

See how we do save_state here:

https://github.com/huggingface/accelerate/blob/main/src/accelerate/checkpointing.py#L76-L89

save makes sure it only occurs on the main process

lucidrains commented 2 years ago

@muellerzr ah ok, but if i'd like to do it manually, i can safely call optimizer.state_dict() underneath the conditional accelerator.is_main_process ?

muellerzr commented 2 years ago

It'd be safer to do is_local_main_process since that also works in multi-node situations, but for XLA you have to save a little differently. So you should do something like:

from accelerate.utils import save

save(optimizer.state_dict(), "my_opt_state_dict.pkl")

Under the hood it performs:

if using_tpu():                            # pseudocode check for an XLA/TPU run
    xm.save(myfile, myfname)
elif accelerator.is_local_main_process:    # property, true only on the local main process
    torch.save(myfile, myfname)

This is what the save utility wraps around fully

lucidrains commented 2 years ago

ok good to know! I need to colocate all my models and optimizers and schedulers (one for each unet) in the same .pt file

I'll definitely use is_local_main_process then, thanks!
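
something like this (a sketch using the save utility; the lists and filename are just placeholders for however the trainer ends up organizing its unets):

from accelerate.utils import save

save(dict(
    unets = [accelerator.unwrap_model(unet).state_dict() for unet in unets],
    optimizers = [opt.state_dict() for opt in optimizers],
    schedulers = [sched.state_dict() for sched in schedulers],
    step = step
), './imagen.checkpoint.pt')  # save already guards on the main process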

muellerzr commented 2 years ago

Sounds good! Also cleaned up that code snippet to show that save wraps both the is_local_main_process check and the TPU check, in case it's useful to know 😄

Also feel free to peruse that doc if you ever think "there's probably something special that needs to be done"; we've likely written a utility func for it. (And if not, that means we probably should!)

AlvL1225 commented 2 years ago

Hi @lucidrains, can it be wrapped with accelerate now? I know the Trainer may need more adaptation work.

lucidrains commented 2 years ago

@yli1994 still working on it within my mind. measure twice cut once kind of thing

jacobwjs commented 2 years ago

> @yli1994 still working on it within my mind. measure twice cut once kind of thing

@lucidrains Any blockers I can help with on this?

lucidrains commented 2 years ago

@jacobwjs not for this, i need to go solo on this. the code will come out better that way

could definitely use some help on https://github.com/lucidrains/DALLE2-pytorch/pull/181 however! you can reach Romain, Aidan, Zion on the Laion discord at any time

jacobwjs commented 2 years ago

> @jacobwjs not for this, i need to go solo on this. the code will come out better that way
>
> could definitely use some help on lucidrains/DALLE2-pytorch#181 however! you can reach Romain, Aidan, Zion on the Laion discord at any time

ok looking forward to seeing what you piece together. happy to test when you're ready.

On the other point, my plate is full diving into some issues with Imagen and ElucidatedImagen (seeing great results btw!). There's something off in the contrast levels between the two, which I think is just a function of tuning the classifier-free guidance levels, the dynamic thresholding percentile, and using the proper attention mechanism.

Another quick thing that popped up: in dynamic thresholding you're clamping s with a min, and the pseudocode in the paper bounds s with a max. Not familiar with JAX, but I assume the two are equivalent?

When I wrap up the above and the text/image non-leaky augmentations we discussed, I'll jump over there and see if I can help out.

lucidrains commented 2 years ago

@jacobwjs yup, taking the max with respect to some value is the same as clamping with that value as minimum
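
i.e. in torch terms (tiny sanity check):

import torch

s = torch.tensor([0.5, 2.0, 3.0])

# paper pseudocode: s = max(s, 1.)  /  code here: s.clamp(min = 1.)
assert torch.equal(s.clamp(min = 1.), torch.maximum(s, torch.ones_like(s)))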

ok sounds good! and yes, i keep hearing about how great elucidated ddpm is. really want to finish off this remaining piece and do some training of my own, perhaps on video. i expected no less of Tero Karras, who is a bona fide genius in DL research

and thanks! testing would definitely be most welcomed once i push the initial code out

jacobwjs commented 2 years ago

Thanks for the clarification. And yep, couldn't agree more, been following his great work since progressive growing of GANs came out.

video is on my roadmap as well. curious to see whats possible.

lucidrains commented 2 years ago

ok, it seems to be working, at least on my two dinky RTX 2080 Tis! https://github.com/lucidrains/imagen-pytorch#multi-gpu-preliminary will definitely be recommending accelerate to a lot of people 😄

nateraw commented 2 years ago

@lucidrains You have a script/snippet I can try on a multi-GPU machine to test before closing? :)

lucidrains commented 2 years ago

@nateraw yup! the one in the readme under the dataloaders section

if you plop that in a file, replace /path/to/training/images with a folder on your machine that has about 100 images, then accelerate config followed by accelerate launch train.py should work :crossed_fingers:

nateraw commented 2 years ago

Worked for me on 2xA5000 🚀 Nice work

muellerzr commented 2 years ago

FYI @lucidrains with the release of 0.11 the gradient accumulation support is finally out 😄 https://github.com/huggingface/accelerate/releases/tag/v0.11.0
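
Usage looks roughly like this (a sketch):

from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps = 4)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    # accumulate() handles the loss scaling and skips gradient syncs on the non-stepping iterations
    with accelerator.accumulate(model):
        loss = model(batch)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()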