explainingai-code / StableDiffusion-PyTorch

This repo implements a Stable Diffusion model in PyTorch with all the essential components.

It would be greatly appreciated if model ckpts could be provided #5

Closed: CatLoves closed this issue 6 months ago

CatLoves commented 6 months ago

Hello @explainingai-code ! I have read your code carefully and can now train the unconditional and text-conditioned CelebHQ models, but I only have a single V100 GPU, and training is expected to take ~110 hours to get one result, which is very time-consuming. Renting a GPU cluster is also very expensive. If you have already trained the model, could you provide download links for the model ckpts? I think this would be very helpful for people who want to use your codebase quickly. Thank you again for your great codebase; it would be greatly appreciated if model ckpts could be provided.

Sincerely, CatLoves

explainingai-code commented 6 months ago

Hello @CatLoves, I am really sorry but I don't have the checkpoints. I don't really train the models to convergence (to avoid cost), just to the point where I can get some decent results, so they wouldn't be checkpoints that give the best results anyway.

The main purpose of these codebases is to reach a point where the code is minimal and understandable for people coming from my videos (and not really to get the best results). The official repos or the libraries from Hugging Face would do a much better job if generating the best images is your objective.

Having said all that, using a single V100 instance, I was able to train the model at 15-20 minutes per epoch (with latents saved). Based on your training time estimate, it seems it's taking much longer for you. Are you using the save_latent config? This precomputes the latents, so at each step you no longer need to do the VAE inference call during LDM training. That should reduce the training time significantly.
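
To illustrate the point, here is a minimal, hypothetical sketch of a training step (the names `vae.encode`, `batch["latent"]`, and `use_latents` are assumptions for illustration, not the repo's actual API): with saved latents, the per-step VAE encoder forward pass disappears entirely.

```python
import torch

def training_step(batch, vae, use_latents):
    """Hypothetical LDM training step showing where save_latent saves time."""
    if use_latents:
        # Latents were precomputed once and loaded by the dataset,
        # so no VAE call is needed during LDM training.
        latents = batch["latent"]
    else:
        # Without saved latents, every step pays for a VAE encoder
        # forward pass on the full-resolution images.
        with torch.no_grad():
            latents = vae.encode(batch["image"])
    # ... compute the diffusion loss on `latents` as usual ...
    return latents
```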

CatLoves commented 6 months ago

Hi @explainingai-code: Thank you for your detailed and helpful response! I managed to enable the save_latent config with the following steps:

  1. Modify the logic in the code (screenshot omitted).
  2. Modify the training params (screenshot omitted).
  3. Run tools/infer_vqvae.py to get saved_latents.pkl (see the sketch after this list).
  4. Rerun tools/sample_ddpm_text_cond.py.

For step 1, I think it may be a small bug worth fixing. Have a nice day!
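
For step 3, here is a hedged sketch of what the latent-precomputation pass might look like. The encoder call, the dict-keyed-by-path format, and the output filename follow the saved_latents.pkl mentioned above, but the exact details are assumptions, not the repo's actual code.

```python
import pickle
import torch

@torch.no_grad()
def save_latents(vae, dataloader, out_path="saved_latents.pkl"):
    """Run the (frozen) VAE encoder once per image and pickle the results."""
    vae.eval()
    latents = {}
    for images, paths in dataloader:   # paths identify each image
        z = vae.encode(images)         # hypothetical encoder call
        for path, latent in zip(paths, z):
            latents[path] = latent.cpu()
    with open(out_path, "wb") as f:
        pickle.dump(latents, f)
```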
explainingai-code commented 6 months ago

Perfect. Just to add, step 1 is not really an issue, because later code takes care of it: https://github.com/explainingai-code/StableDiffusion-PyTorch/blob/main/dataset/celeb_dataset.py#L42-L50. So as long as latents exist for ALL the images and the dataset class initialization receives use_latents=True (which it does through the training script), latents will be loaded; otherwise VAE inference will be done.
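
For reference, an illustrative sketch (not the repo's exact code) of the fallback those linked lines describe: precomputed latents are used only when one exists for every image; otherwise the dataset falls back to on-the-fly VAE inference.

```python
import os
import pickle

class LatentAwareDataset:
    """Illustrative dataset that prefers precomputed latents when complete."""

    def __init__(self, image_paths, latents_path, use_latents=True):
        self.image_paths = image_paths
        self.latents = {}
        if use_latents and os.path.exists(latents_path):
            with open(latents_path, "rb") as f:
                self.latents = pickle.load(f)
        # Only use latents if EVERY image has one; otherwise fall back
        # to running VAE inference during training.
        self.use_latents = use_latents and all(
            p in self.latents for p in self.image_paths
        )

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        if self.use_latents:
            return self.latents[path]
        return path  # caller loads the image and runs the VAE encoder
```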