Closed: GeGehub closed this issue 1 year ago.
Following this post.
@hamzafar You can hit the Subscribe button on the right to follow this thread.
Hi everybody. We just added a brand new DiT PyTorch training script (train.py) to the repo. Note that you'll need to update to the latest version of the repo to use it. Sorry for the delay!
The script isn't thoroughly tested yet; so far we've only trained a 256x256 DiT-XL/2 model from scratch for 90K steps on an A100 node (8x GPUs), but the loss curve looks correct (at least up to that point), and FID-50K at 50K steps closely matches the JAX version's. If you encounter any bugs, please open a new issue and I'll do my best to take a look.
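For anyone looking to try it, a multi-GPU launch on a single 8-GPU node would typically go through `torchrun`. This is a sketch only: the flag names (`--model`, `--data-path`) and the dataset path are assumptions, so check the repo's README for the script's actual interface.

```shell
# Sketch: launch train.py with PyTorch's torchrun on one node with 8 GPUs.
# Flag names and the data path below are assumptions; see the repo README
# for the real command-line interface.
torchrun --nnodes=1 --nproc_per_node=8 train.py \
  --model DiT-XL/2 \
  --data-path /path/to/imagenet/train
```

`torchrun` spawns one process per GPU and sets up the environment variables that PyTorch's distributed data-parallel training expects.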
@wpeebles Thanks for sharing your incredible work.
Update: since the training script was released, I've trained a few XL/2 and B/4 models. In all experiments, the PyTorch-trained models perform very similarly to the JAX-trained ones (sometimes slightly better, in fact). I added a bunch of info to the README here. Just make sure you update to the latest version of the repo.
Great job; I wonder when you will release the training code and scripts.
Thanks