lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

fix prior saving / add load method #44

Closed rom1504 closed 2 years ago

TheoCoombes commented 2 years ago

The following code worked for loading the prior checkpoint from Hugging Face, using only the stock default hparams:

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork

device = "cuda:0"
# map_location places the tensors on the target device regardless of where they were saved
state_dict = torch.load("model.pth", map_location=device)

# Default hparams; these must match the ones the checkpoint was trained with
prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=6,
    dim_head=64,
    heads=8
).to(device)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=None,
    image_embed_dim=768,
    timesteps=100,
    cond_drop_prob=0.2,
    loss_type="l2",
    condition_on_text_encodings=False
).to(device)

diffusion_prior.load_state_dict(state_dict)
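
If the guessed hparams do not match the checkpoint, `load_state_dict` raises on the mismatched keys. A minimal sketch of inspecting the mismatch with `strict=False`, using a tiny stand-in module (`TinyNet` and `report_key_mismatch` are hypothetical names, not part of dalle2_pytorch). Note that `strict=False` only tolerates missing or extra keys; a shape mismatch on a shared key still raises.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    # Hypothetical stand-in for DiffusionPrior; only the key-checking pattern matters.
    def __init__(self, dim=4, with_extra=False):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        if with_extra:
            self.extra = nn.Linear(dim, dim)


def report_key_mismatch(model, state_dict):
    # strict=False loads the matching keys and reports the ones that did not line up.
    result = model.load_state_dict(state_dict, strict=False)
    return sorted(result.missing_keys), sorted(result.unexpected_keys)


# The checkpoint has an extra layer that the freshly built model lacks.
saved = TinyNet(with_extra=True).state_dict()
missing, unexpected = report_key_mismatch(TinyNet(with_extra=False), saved)
print("missing:", missing)
print("unexpected:", unexpected)
```

An empty result for both lists suggests the architecture matches the checkpoint; anything else points at an hparam that differs from the one used in training.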

rom1504 commented 2 years ago

Here's an example of how to do this right:
https://github.com/lucidrains/DALLE-pytorch/blob/main/train_dalle.py#L535
https://github.com/lucidrains/DALLE-pytorch/blob/main/train_dalle.py#L301
https://github.com/lucidrains/DALLE-pytorch/blob/main/train_dalle.py#L427
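
The general idea, as a rough sketch: store the hparams in the same file as the weights, so the loader can rebuild the model without guessing. This assumes a checkpoint dict with `hparams` and `weights` keys. `TinyPrior`, `save_checkpoint`, and `load_checkpoint` are hypothetical names for illustration, not the DALLE2-pytorch API.

```python
import io

import torch
from torch import nn


class TinyPrior(nn.Module):
    # Hypothetical stand-in for the real prior; it only needs constructor hparams.
    def __init__(self, dim=8, depth=2):
        super().__init__()
        self.layers = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])


def save_checkpoint(model, hparams, f):
    # Keep the constructor arguments next to the weights in one object.
    torch.save({"hparams": hparams, "weights": model.state_dict()}, f)


def load_checkpoint(f, device="cpu"):
    obj = torch.load(f, map_location=device)
    # Rebuild the model from the saved hparams, then restore the weights.
    model = TinyPrior(**obj["hparams"]).to(device)
    model.load_state_dict(obj["weights"])
    return model, obj["hparams"]


hparams = dict(dim=8, depth=3)
original = TinyPrior(**hparams)

buffer = io.BytesIO()  # in-memory file; a path such as "model.pth" works the same way
save_checkpoint(original, hparams, buffer)
buffer.seek(0)

restored, restored_hparams = load_checkpoint(buffer)
```

With this layout the caller never has to know the depth or dim used at training time, which avoids the guesswork in the snippet above.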

lucidrains commented 2 years ago

addressed here i believe https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/train.py#L52