facebookresearch / DiT

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

Training scripts #2

Closed GeGehub closed 1 year ago

GeGehub commented 1 year ago

Great job; I wonder when you will release the training code and scripts.

Thanks

hamzafar commented 1 year ago

Following this post.

phcerdan commented 1 year ago

@hamzafar You can hit the Subscribe button on the right to follow this thread.

[screenshot: the Subscribe button in the GitHub issue sidebar]

wpeebles commented 1 year ago

Hi everybody. We just added a brand new DiT PyTorch training script (train.py) to the repo. Note that you'll need to update to the latest version of the repo to use it. Sorry for the delay!

The script isn't thoroughly tested yet; so far we've only trained a 256x256 DiT-XL/2 model from scratch for 90K steps on a single A100 node (8x GPUs). The loss curve looks correct (at least up to that point), and FID-50K at 50K steps is very similar to the JAX version's. If you encounter any bugs, please open a new issue and I'll try my best to take a look.
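For anyone who wants a quick mental model before reading the full script: training is launched with `torchrun` (something like `torchrun --nnodes=1 --nproc_per_node=8 train.py --model DiT-XL/2 --data-path /path/to/imagenet/train`), and the inner loop boils down to sampling timesteps and minimizing the diffusion loss on VAE latents. Below is a rough sketch of that step, not the exact contents of train.py; `DiT_models`, `create_diffusion`, and `training_losses` are the repo's real entry points, but the shapes, hyperparameters, and random stand-in batch are illustrative:

```python
# Rough sketch of the core DiT training step (illustrative, not train.py itself).
import torch
from models import DiT_models
from diffusion import create_diffusion

device = "cuda"
# 256x256 images become 4x32x32 latents after the 8x-downsampling VAE.
model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000).to(device)
diffusion = create_diffusion(timestep_respacing="")  # default 1000-step training schedule
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)

# Stand-in batch: VAE latents x and class labels y (real training first
# encodes ImageNet images with the Stable Diffusion VAE).
x = torch.randn(8, 4, 32, 32, device=device)
y = torch.randint(0, 1000, (8,), device=device)

# Sample a random diffusion timestep per example and take one gradient step.
t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs=dict(y=y))
loss = loss_dict["loss"].mean()
opt.zero_grad()
loss.backward()
opt.step()
```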

YTEP-ZHI commented 1 year ago

@wpeebles Thanks for sharing your incredible work.

wpeebles commented 1 year ago

Update: since the training script was released, I've trained a few XL/2 and B/4 models. In all experiments, the PyTorch-trained models perform very similarly to the JAX ones (and sometimes slightly better). I added a bunch of info to the README here. Just make sure you update to the latest version of the repo.
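For reference, loading one of the released checkpoints looks roughly like this. It's a sketch adapted from the repo's sample.py; the filename below matches the released 256x256 XL/2 checkpoint, and the rest is illustrative:

```python
# Rough sketch of loading a released DiT checkpoint for sampling
# (adapted from sample.py; find_model downloads the named pretrained
# checkpoint if it isn't already present locally).
import torch
from models import DiT_models
from download import find_model

device = "cuda"
model = DiT_models["DiT-XL/2"](input_size=32, num_classes=1000).to(device)
state_dict = find_model("DiT-XL-2-256x256.pt")
model.load_state_dict(state_dict)
model.eval()  # important: put the model in eval mode before sampling
```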