Cannot generate normal images with pretrained model #282

Open LIUHAO121 opened 1 year ago

LIUHAO121 commented 1 year ago

This is my code for generating images, but the generated image is random. Prior model: [link] Decoder model: [link]

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch import Unet, Decoder,DALLE2

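prior_network = DiffusionPriorNetwork(
    final_proj = True,
    rotary_emb = True
    # remaining arguments are missing from the post
)

diffusion_prior = DiffusionPrior(
    # sample_timesteps = 64,
    # remaining arguments are missing from the post
)
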
unet = Unet(
    **{"dim": 320,
    "cond_dim": 512,
    "image_embed_dim": 768,
    "text_embed_dim": 768,
    "cond_on_text_encodings": True,
    "channels": 3,
    "dim_mults": [1, 2, 3, 4],
    "num_resnet_blocks": 4,
    "attn_heads": 8,
    "attn_dim_head": 64,
    "sparse_attn": True,
    "memory_efficient": True,
    "self_attn": [False, True, True, True]}
)

decoder = Decoder(
    unet = unet,
    timesteps = 1000,
    image_sizes = [64],
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
)

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2., # classifier free guidance strength (> 1 would strengthen the condition)
    return_pil_images = True # return PIL images so they can be saved directly
)

for img in images:
    img.save("out.jpg")

[Screenshot 2023-03-08 18:53:58]

cest-andre commented 1 year ago

I initially also had trouble getting pretrained weights working properly with this repository but I resolved the problem. Not sure if you have the same problem, but I'll relay in case it helps.

First, it helped to set strict=True so that I could see the discrepancies between the weights and the models. The key to fixing my problem was noticing that some parameters defined in the model did not exist in the pth file. Since the latest pth file was ~8 months old, I downgraded dalle2-pytorch to version 1.1.0.
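
The comparison that strict=True performs can also be done by hand before loading. Below is a minimal sketch in plain Python, assuming the key names are available as lists (with PyTorch they would come from the model's state_dict() and from the loaded pth file). `diff_state_dict_keys` is a hypothetical helper, not part of dalle2-pytorch, and the key names are illustrative only.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, mirroring a strict load check.

    missing:    keys the model defines but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    """
    model_set = set(model_keys)
    checkpoint_set = set(checkpoint_keys)
    missing = sorted(model_set - checkpoint_set)
    unexpected = sorted(checkpoint_set - model_set)
    return missing, unexpected


# Illustrative key names only; real ones come from the model and the pth file.
model_keys = ["net.to_out.weight", "net.null_text_embeds", "net.null_image_embed"]
checkpoint_keys = ["net.to_out.weight", "net.null_text_embed"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

A non-empty result in either list is a hint that the installed package version and the checkpoint were not made for each other.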

The prior weights then loaded correctly, but the decoder pth was still missing parameters related to CLIP. I took the state_dict from the pretrained CLIP and merged it into the decoder's pth-derived state_dict.

Finally, I had to modify the installed dalle2_pytorch source, as there seemed to be a bug in the DALLE2 class's forward method. On line 2940, image_embed is passed to decoder.sample positionally instead of as the image_embed keyword argument, which produces an error when called. This bug has been fixed in more recent versions of the repo, but for 1.1.0 you need to change the line to:

images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
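
To illustrate why the keyword matters: a positional value binds to the first parameter, whatever its name. The sketch below uses a hypothetical `sample` function whose first parameter is assumed to be something other than image_embed. It is a stand-in for illustration, not the real Decoder.sample signature.

```python
def sample(image=None, image_embed=None, text=None, cond_scale=1.0):
    """Hypothetical stand-in that reports which parameters were supplied."""
    if image_embed is None:
        raise ValueError("image_embed must be given")
    return {
        "image": image,
        "image_embed": image_embed,
        "text": text,
        "cond_scale": cond_scale,
    }


embed = [0.1, 0.2, 0.3]  # placeholder for an image embedding

# Positional call: the value binds to `image`, so image_embed stays None.
try:
    sample(embed, text=None, cond_scale=2.0)
    positional_ok = True
except ValueError:
    positional_ok = False

# Keyword call: the value reaches the intended parameter.
result = sample(image_embed=embed, text=None, cond_scale=2.0)

print(positional_ok)
print(result["image_embed"])
```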

Here is my script to get things running:

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()  # display each generated image

Let me know if you have any questions. Hope this helps.

klei22 commented 1 year ago

@cest-andre Hi! I am seeing the same error as LIUHAO121 even with the above, and wondering if I might be using an incorrect set of files (prior_config.json, decoder_config.json, decoder_latest.pth, and prior_latest.pth).

I'm currently downloading these files from the huggingface repository -- is this correct?

Also, would the JSON files need to be modified after downloading (e.g. to fix paths)?

cest-andre commented 1 year ago

@klei22 Apologies, I did rename the pth files so that may have added confusion.

For prior, I am using latest.pth and prior_config.json in huggingface's prior folder. For decoder, I am using latest.pth and decoder_config.json in decoder/v1.0.2 folder.

Note that in the code I posted, I modify the keys in decoder's state dictionary so that it recognizes the CLIP weights.
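
The key change amounts to copying one state dictionary into another under a prefix. Here is a minimal sketch with plain dicts standing in for state dicts. `merge_with_prefix` is a hypothetical helper, and the key names and values are placeholders rather than real parameter names.

```python
def merge_with_prefix(target_state, source_state, prefix):
    """Return a copy of target_state with source_state added under prefix + key."""
    merged = dict(target_state)  # shallow copy so the input dict is not mutated
    for key, value in source_state.items():
        merged[prefix + key] = value
    return merged


# Placeholder entries; real state dicts map parameter names to tensors.
decoder_state = {"unets.0.init_conv.weight": "w0"}
clip_state = {"visual.proj": "w1", "logit_scale": "w2"}

merged = merge_with_prefix(decoder_state, clip_state, "clip.")
print(sorted(merged))
```

After the merge, every key the decoder expects under clip. is present, so a strict load no longer reports them as missing.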

tikitong commented 1 year ago

@cest-andre Thank you very much for sharing. It doesn't seem to work on my end either. Can you share an example of the input text and the output image you get?

cest-andre commented 1 year ago

@tikitong Sure thing. A reminder that this fix requires downgrading dalle2-pytorch to version 1.1.0. Could you be more specific about what doesn't work on your end? Do you get an error or do you get bad image results?

Here are the results from the prompt 'a field of flowers':


ZhangxinruBIT commented 1 year ago

@cest-andre @tikitong Hi, and many thanks for the reminder.

I ran into the same problem. I first pinned dalle2-pytorch to version 1.1.0 and double-checked it:

>>> dalle2_pytorch.__version__
'1.1.0'

Then I modified the code on line 2940 of the source file.

Finally, I downloaded the pretrained models and their config JSON files, the same as you.

With input text

a field of flowers

No error reported

sampling loop time step: 100%|███████████████████████████████████| 64/64 [00:02<00:00, 25.80it/s]
sampling loop time step: 100%|███████████████████████████████| 1000/1000 [03:37<00:00, 4.60it/s]
1it [03:41, 221.34s/it]: 100%|███████████████████████████████| 1000/1000 [03:37<00:00, 4.70it/s]
(1, 3, 64, 64)

I got a 64×64 image as the result, shown below:


I have no idea what the issue is now.

cest-andre commented 1 year ago


You didn't mention changing the keys in the decoder. This was something I mentioned and included in the code.

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

I only discovered this because I had strict=True in load_state_dict. Without the above modification, the script would error out, as the keys of the pth did not match the model defined in the code. If strict=False is passed (PyTorch's default is strict=True), then no error occurs but the weights are not properly loaded, so noisy results appear.

First set strict=True for both the prior and decoder. If the key mismatch occurs for the decoder and it mentions missing keys referring to clip, then paste the above code between the torch.load call and the load_state_dict call.
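
If strict=False has to stay for some reason, the mismatch can still be caught: PyTorch's load_state_dict returns an object with missing_keys and unexpected_keys fields. The sketch below checks such a result. A namedtuple stands in for the real return value so that the example runs without PyTorch, and `summarize_load` is a hypothetical helper.

```python
from collections import namedtuple

# Stand-in with the same two fields as the result of load_state_dict.
LoadResult = namedtuple("LoadResult", ["missing_keys", "unexpected_keys"])


def summarize_load(result):
    """Return a short summary, or raise if any key failed to match."""
    problems = []
    if result.missing_keys:
        problems.append(f"missing: {list(result.missing_keys)}")
    if result.unexpected_keys:
        problems.append(f"unexpected: {list(result.unexpected_keys)}")
    if problems:
        raise RuntimeError("state dict mismatch; " + "; ".join(problems))
    return "all keys matched"


clean = LoadResult(missing_keys=[], unexpected_keys=[])
print(summarize_load(clean))

# Illustrative key name only.
partial = LoadResult(missing_keys=["clip.logit_scale"], unexpected_keys=[])
try:
    summarize_load(partial)
except RuntimeError as err:
    print(err)
```

With a real model, the argument would be the value returned by a call such as decoder.load_state_dict(decoder_model_state, strict=False).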

tikitong commented 1 year ago

@cest-andre Thanks for your reply too! I get no error, but the image does not match the input text. Here are the results from the prompt "a red car": [image: img2]

Do you have any idea where the problem might come from?

Here is the code I used. The weight files are the ones you indicated before (for the prior, latest.pth and prior_config.json in huggingface's prior folder; for the decoder, latest.pth and decoder_config.json in the decoder/v1.0.2 folder).

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create()

prior_model_state = torch.load("weights/prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create()

decoder_model_state = torch.load("weights/decoder_latest.pth", map_location=torch.device('cpu'))["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder)

images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()  # display each generated image

Here is the configuration of my conda environment:

cest-andre commented 1 year ago

@tikitong I think you've got it working. The fact that you have a clean image that's at least car related makes me think it's working properly. The imperfect results are more of a function of the limitations of the model rather than any coding mistakes. I've also received some "off" results.

The model is non-deterministic, so you can run it multiple times and see if you get better images. But I think you're good to go.

tikitong commented 1 year ago

@cest-andre thanks again for your time !

kdavidlp123 commented 1 year ago

@cest-andre Hi, and many thanks for the reminder. When I used the same code as yours, I got an error:

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path(r"C:\Users\user\prior_config.json").prior
prior = prior_config.create()

prior_model_state = torch.load(r"C:\Users\user\prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path(r"C:\Users\user\decoder_config.json").decoder
decoder = decoder_config.create()

decoder_model_state = torch.load(r"C:\Users\user\decoder_latest.pth", map_location=torch.device('cpu'))["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder)

images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()  # display each generated image

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 24
     20 decoder.load_state_dict(decoder_model_state, strict=True)
     22 dalle2 = DALLE2(prior=prior, decoder=decoder)
---> 24 images = dalle2(
     25     ['a red car'],
     26     cond_scale = 2.
     27 ).cpu()
     29 for img in images:
     30     img = ToPILImage()(img)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\nn\modules\, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     93 was_training =
     94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
     96 model.train(was_training)
     97 return out

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DALLE2.forward(self, text, cond_scale, prior_cond_scale, return_pil_images)
   2934     text = [text] if not isinstance(text, (list, tuple)) else text
   2935     text = tokenizer.tokenize(text).to(device)
-> 2937 image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
   2939 text_cond = text if self.decoder_need_text_cond else None
   2940 images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     93 was_training =
     94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
     96 model.train(was_training)
     97 return out

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DiffusionPrior.sample(self, text, num_samples_per_batch, cond_scale, timesteps)
   1209 if self.condition_on_text_encodings:
   1210     text_cond = {**text_cond, 'text_encodings': text_encodings}
-> 1212 image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
   1214 # retrieve original unscaled image embed
   1216 image_embeds /= self.image_embed_scale

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DiffusionPrior.p_sample_loop(self, timesteps, *args, **kwargs)
   1150 if not is_ddim:
   1151     return self.p_sample_loop_ddpm(*args, **kwargs)
-> 1153 return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DiffusionPrior.p_sample_loop_ddim(self, shape, text_cond, timesteps, eta, cond_scale)
   1112 alpha_next = alphas[time_next]
   1114 time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-> 1116 pred =, time_cond, cond_scale = cond_scale, **text_cond)
   1118 if self.predict_x_start:
   1119     x_start = pred

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DiffusionPriorNetwork.forward_with_cond_scale(self, cond_scale, *args, **kwargs)
    867 def forward_with_cond_scale(
    868     self,
    869     *args,
    870     cond_scale = 1.,
    871     **kwargs
    872 ):
--> 873     logits = self.forward(*args, **kwargs)
    875     if cond_scale == 1:
    876         return logits

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\, in DiffusionPriorNetwork.forward(self, image_embed, diffusion_timesteps, text_embed, text_encodings, cond_drop_prob)
    918     mask = F.pad(mask, (0, remainder), value = False)
    920 null_text_embeds =
--> 922 text_encodings = torch.where(
    923     rearrange(mask, 'b n -> b n 1').clone(),
    924     text_encodings,
    925     null_text_embeds
    926 )
    928 # classifier free guidance
    930 keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)

**RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and cpu respectively**

Can you help me to figure it out?

cest-andre commented 1 year ago


You're not using exactly the same code, since you're loading your prior and decoder onto the CPU. I've only run everything on GPU, so I'm not sure what the exact problem is. Maybe you also need to move the DALLE2 object to the CPU.
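
The RuntimeError above means that tensors living on two different devices met in one operation. One way to track that down is to list which tensors sit on which device. The sketch below groups names by device in plain Python. The snapshot is hypothetical: with PyTorch it could be built as {name: tensor.device} from the model's named_parameters() and named_buffers().

```python
from collections import defaultdict


def group_by_device(tensor_devices):
    """Group tensor names by the device they live on."""
    groups = defaultdict(list)
    for name, device in tensor_devices.items():
        groups[str(device)].append(name)
    return dict(groups)


# Hypothetical snapshot; the names and devices are illustrative only.
snapshot = {
    "prior.net.null_text_embeds": "cpu",
    "prior.net.to_out.weight": "cuda:0",
    "decoder.unets.0.init_conv.weight": "cuda:0",
}

groups = group_by_device(snapshot)
if len(groups) > 1:
    print("mixed devices:", groups)
```

More than one group means some part of the pipeline still has to be moved so that everything ends up on a single device.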

kdavidlp123 commented 1 year ago

@cest-andre Thank you for your reply. After I put them all on the CPU, it worked. I would also like to ask how to generate a larger picture. Does that depend on the model, or on something else?

cest-andre commented 1 year ago

@kdavidlp123 Open the decoder_config json and change "image_sizes". I was having trouble setting it to a higher value. The images would look very oddly distorted. I don't know much about the inner workings of DALLE so I suppose it cannot generalize to other dimensions without retraining. Not sure though.
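
For anyone who wants to inspect the setting before editing it, here is a minimal sketch that reads image_sizes from a config. The JSON text is a heavily trimmed, hypothetical stand-in for decoder_config.json, and the nesting under "decoder" is an assumption based on the .decoder attribute used in the scripts above.

```python
import json

# Hypothetical, heavily trimmed stand-in for decoder_config.json.
config_text = '{"decoder": {"image_sizes": [64], "timesteps": 1000}}'


def get_image_sizes(config):
    """Return the decoder's image_sizes list, or None if the key is absent."""
    return config.get("decoder", {}).get("image_sizes")


config = json.loads(config_text)
print(get_image_sizes(config))
```

As noted above, changing the value does not mean the pretrained weights will produce good images at the new size.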

4daJKong commented 1 year ago

Hi, @tikitong sorry to bother you, where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from [link] and [link]. However, it shows an error in

prior.load_state_dict(prior_model_state, strict=True)
decoder.load_state_dict(decoder_model_state, strict=True)

If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:

RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".

cezeilo commented 1 year ago


> You didn't mention changing the keys in the decoder. This was something I mentioned and included in the code.
>
> for k in decoder.clip.state_dict().keys():
>     decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
>
> I only discovered this because I had strict=True in load_state_dict. Without the above modification, the script would error out, as the keys of the pth did not match the model defined in the code. If strict=False is passed (PyTorch's default is strict=True), then no error occurs but the weights are not properly loaded, so noisy results appear.
>
> First set strict=True for both the prior and decoder. If the key mismatch occurs for the decoder and it mentions missing keys referring to clip, then paste the above code between the torch.load call and the load_state_dict call.

This worked for me, thank you!

YANDaoyu commented 1 year ago

> Hi, @tikitong sorry to bother you, where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from [link] and [link]. However, it shows an error in
>
> prior.load_state_dict(prior_model_state, strict=True)
> decoder.load_state_dict(decoder_model_state, strict=True)
>
> If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:
>
> RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".

Hi, I solved this problem by deleting the local folder named dalle2_pytorch, because I found that when it exists, the version of dalle2_pytorch that gets used is 1.14.0, not 1.1.0. Maybe you can try that too?
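
A quick way to see whether a local folder is shadowing the installed package is to ask Python where it would import the module from and which version pip has recorded. A minimal sketch using only the standard library follows. `describe_package` is a hypothetical helper, and the assumption is that the import name is dalle2_pytorch and the distribution name is dalle2-pytorch.

```python
import importlib.util
from importlib import metadata


def describe_package(import_name, dist_name):
    """Report where import_name would be loaded from and the installed version."""
    spec = importlib.util.find_spec(import_name)
    location = spec.origin if spec is not None else None
    try:
        version = metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        version = None
    return {"location": location, "version": version}


info = describe_package("dalle2_pytorch", "dalle2-pytorch")
print(info)
```

If the reported location is inside the working directory rather than site-packages, the local copy is the one being imported, whatever version pip reports.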

fido20160817 commented 1 year ago

> > Hi, @tikitong sorry to bother you, where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from [link] and [link]. However, it shows an error in
> >
> > prior.load_state_dict(prior_model_state, strict=True)
> > decoder.load_state_dict(decoder_model_state, strict=True)
> >
> > If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:
> >
> > RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".
>
> Hi, I solved this problem by deleting the local folder named dalle2_pytorch, because I found that when it exists, the version of dalle2_pytorch that gets used is 1.14.0, not 1.1.0. Maybe you can try that too?

I got it running by following your answer. Many thanks!

JasbirCodeSpace commented 1 year ago

@cest-andre can you please attach the updated code here for others to use as an initial reference, that will be a great help.

cest-andre commented 1 year ago

> @cest-andre can you please attach the updated code here for others to use as an initial reference, that will be a great help.

The bottom of my first comment above has the code.

JasbirCodeSpace commented 1 year ago

@cest-andre I'm getting the following error while running the mentioned code:

Traceback (most recent call last):
  File "/home/ubuntu/dalle2/", line 4, in <module>
    from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/dalle2_pytorch/", line 34, in <module>
    class TrainSplitConfig(BaseModel):
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/dalle2_pytorch/", line 40, in TrainSplitConfig
    def validate_all(cls, fields):
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/pydantic/deprecated/", line 222, in root_validator
    return root_validator()(*__args) # type: ignore
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/pydantic/deprecated/", line 228, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use @root_validator with pre=False (the default) you MUST specify skip_on_failure=True. Note that @root_validator is deprecated and should be replaced with @model_validator.

For further information visit [link]

Solution: I installed pydantic==1.10.6.
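
The traceback passes through pydantic's deprecated module, which points to a pydantic 2.x install running code written for the 1.x API. A small guard like the one below could flag that before the import fails. The helper names are hypothetical, and the only rule it encodes is the major version, based on the pin that worked here.

```python
def major_version(version_string):
    """Return the leading integer of a dotted version string such as '1.10.6'."""
    return int(version_string.split(".")[0])


def uses_pydantic_v1(version_string):
    """True when the given pydantic version belongs to the 1.x series."""
    return major_version(version_string) < 2


print(uses_pydantic_v1("1.10.6"))
print(uses_pydantic_v1("2.5.0"))
```

In a real environment, the version string could be read with importlib.metadata.version("pydantic").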

esteban-rs commented 1 year ago

@cest-andre Hi, I got an error from the decoder. I suppose it is caused by the dalle2-pytorch version. Does it actually work with the current file in the repository?

decoder.load_state_dict(decoder_model_state, strict=True)
_IncompatibleKeys(missing_keys=['', '', '', '', '', '', '', ''], unexpected_keys=['', '', '', '', '', '', '', ''])

hanghaju commented 11 months ago

This is my code for generating images, but the generated images are very blurry. Prior model: [link] Decoder model: [link]

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    text = ['A high quality photo of Times Square'],
    cond_scale = 2.
)

for img in images:
    # convert each tensor to a PIL image before saving
    ToPILImage()(img.cpu()).save("out.jpg")

[image]

ctxya1207 commented 9 months ago

[image] What is the reason for this error?

ALLIZZWELL123 commented 7 months ago

> > Hi, @tikitong sorry to bother you, where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from [link] and [link]. However, it shows an error in
> >
> > prior.load_state_dict(prior_model_state, strict=True)
> > decoder.load_state_dict(decoder_model_state, strict=True)
> >
> > If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:
> >
> > RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".
>
> Hi, I solved this problem by deleting the local folder named dalle2_pytorch, because I found that when it exists, the version of dalle2_pytorch that gets used is 1.14.0, not 1.1.0. Maybe you can try that too?

Where is the cuz folder, bro?

thePOET8 commented 5 months ago

> > > Hi, @tikitong sorry to bother you, where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from [link] and [link]. However, it shows an error in
> > >
> > > prior.load_state_dict(prior_model_state, strict=True)
> > > decoder.load_state_dict(decoder_model_state, strict=True)
> > >
> > > If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:
> > >
> > > RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".
> >
> > Hi, I solved this problem by deleting the local folder named dalle2_pytorch, because I found that when it exists, the version of dalle2_pytorch that gets used is 1.14.0, not 1.1.0. Maybe you can try that too?
>
> Where is the cuz folder, bro?

I have the same problem. I changed the version to 1.1.0, but there are still errors.