LIUHAO121 opened this issue 1 year ago.
I initially also had trouble getting pretrained weights working properly with this repository, but I resolved the problem. I'm not sure if you have the same problem, but I'll share what I did in case it helps.
First, it was helpful to set strict=True in load_state_dict so that I could see the discrepancies between the weights and the models. The key to fixing my problem was noticing that the model defined parameters which did not exist in the .pth file. Seeing that the latest .pth file was ~8 months old, I downgraded dalle2-pytorch to version 1.1.0.
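The same discrepancies can also be listed directly by comparing the two key sets. A minimal sketch in plain Python; `diff_state_dict_keys` is a hypothetical helper, not part of dalle2-pytorch or PyTorch. With real models you would pass `model.state_dict().keys()` and the keys of the dict returned by `torch.load`:

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Compare a model's parameter names with a checkpoint's keys.

    Returns (missing, unexpected), the same two lists that
    load_state_dict(strict=True) complains about:
      missing    = defined by the model but absent from the checkpoint
      unexpected = present in the checkpoint but unknown to the model
    """
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


if __name__ == "__main__":
    # Toy key lists (hypothetical) shaped like the prior mismatch reported in this thread.
    model_keys = ["net.null_text_embeds", "net.null_image_embed", "net.shared.weight"]
    checkpoint_keys = ["net.null_text_embed", "net.shared.weight"]
    missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
    print("missing:", missing)
    print("unexpected:", unexpected)
```

strict=True only succeeds when both lists are empty.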
With that, the prior weights loaded correctly, but the decoder .pth was still missing parameters related to CLIP. I then took the state_dict of the pretrained CLIP model and copied its entries (with a "clip." key prefix) into the decoder's state_dict loaded from the .pth file.
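The CLIP step can be written as a small reusable helper. This is a sketch under assumptions: `merge_prefixed` is a hypothetical name, and unlike the loop in the script below, it only fills keys the checkpoint does not already contain unless `overwrite=True` is passed:

```python
def merge_prefixed(checkpoint_state, sub_state, prefix, overwrite=False):
    """Copy entries of sub_state into checkpoint_state under `prefix`.

    By default only keys absent from the checkpoint are filled in, so any
    weights the checkpoint already contains are left untouched.
    Returns the list of keys that were added (or replaced, with overwrite=True).
    """
    changed = []
    for key, value in sub_state.items():
        full_key = prefix + key
        if overwrite or full_key not in checkpoint_state:
            checkpoint_state[full_key] = value
            changed.append(full_key)
    return changed
```

With the objects from the script, the call would be `merge_prefixed(decoder_model_state, decoder.clip.state_dict(), "clip.")`, placed before `decoder.load_state_dict(...)`. The returned list makes it easy to check that only clip.* keys were touched.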
Finally, I had to modify the dalle2_pytorch.py file, as there seemed to be a bug in the DALLE2 class's forward method. On line 2940, image_embed is passed to decoder.sample positionally instead of as the image_embed keyword argument, which produces an error when called. This bug has been fixed in more recent versions of the repo, but for 1.1.0 you need to change that line to:
images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)
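Instead of editing the installed file by hand, the one-line change can be applied with a short script. A sketch, assuming the installed line reads exactly as shown in the traceback later in this thread and that the file sits at dalle2_pytorch/dalle2_pytorch.py inside the installed package; `patch_decoder_call` and `patch_installed_file` are hypothetical helpers, and the file is overwritten without a backup:

```python
import importlib.util
from pathlib import Path

# Line as it appears in dalle2-pytorch 1.1.0 according to the traceback in this thread.
BUGGY = "images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)"
FIXED = "images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)"


def patch_decoder_call(source):
    """Return `source` with the positional image_embed argument made a keyword argument."""
    if FIXED in source:
        return source  # already patched, nothing to do
    if BUGGY not in source:
        raise ValueError("expected line not found; is this dalle2-pytorch 1.1.0?")
    return source.replace(BUGGY, FIXED, 1)


def patch_installed_file(package="dalle2_pytorch"):
    """Patch <package dir>/dalle2_pytorch.py in place (no backup is made)."""
    spec = importlib.util.find_spec(package)  # locates a top-level package without importing it
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(package)
    path = Path(spec.origin).parent / "dalle2_pytorch.py"
    source = path.read_text(encoding="utf-8")
    path.write_text(patch_decoder_call(source), encoding="utf-8")
    return path
```

Running `patch_installed_file()` once in the environment is enough; calling it again is a no-op because already-patched source is returned unchanged.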
Here is my script to get things running:
import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()
prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()
decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]
# copy the pretrained CLIP weights into the decoder checkpoint under the "clip." prefix
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
decoder.load_state_dict(decoder_model_state, strict=True)
dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()
images = dalle2(
['your prompt here'],
cond_scale = 2.
).cpu()
print(images.shape)
for img in images:
    img = ToPILImage()(img)
    img.show()
Let me know if you have any questions. Hope this helps.
@cest-andre Hi! I am seeing the same error as LIUHAO121 even with the above, and I'm wondering if I might be using an incorrect set of files (prior_config.json, decoder_config.json, decoder_latest.pth, and prior_latest.pth).
I'm currently downloading these files from the Hugging Face repository. Is this correct?
Also, do the JSON files need to be modified after downloading (e.g., to fix paths)?
@klei22 Apologies, I renamed the .pth files, which may have added confusion.
For the prior, I am using latest.pth and prior_config.json from the prior folder of the Hugging Face repo. For the decoder, I am using latest.pth and decoder_config.json from the decoder/v1.0.2 folder.
Note that in the code I posted, I modify the keys in decoder's state dictionary so that it recognizes the CLIP weights.
@cest-andre thank you very much for sharing. It doesn't seem to work on my end either. Can you share an example of the input text and the image you get as output?
@tikitong Sure thing. A reminder that this fix requires downgrading dalle2-pytorch to version 1.1.0. Could you be more specific about what doesn't work on your end? Do you get an error or do you get bad image results?
Here are the results from the prompt 'a field of flowers':
@cest-andre @tikitong Hi, and many thanks for the reminder.
I had also made the same mistake. I first pinned dalle2-pytorch to version 1.1.0 and double-checked it:
>>> dalle2_pytorch.__version__
'1.1.0'
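The same check can be scripted with the standard library. A minimal sketch; `installed_version` and `require_exact` are hypothetical helper names. Note that importlib.metadata reports the installed distribution, which is not necessarily the code that `import dalle2_pytorch` actually loads if another folder with that name comes first on sys.path:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Installed version string of a distribution, or None if it is not installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def require_exact(dist_name, expected):
    """Raise if the installed version differs from `expected`; return it otherwise."""
    found = installed_version(dist_name)
    if found != expected:
        raise RuntimeError(f"{dist_name}=={expected} required, found {found}")
    return found


if __name__ == "__main__":
    print("dalle2-pytorch:", installed_version("dalle2-pytorch"))
```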
Then I modified the code on line 2940 of the dalle2_pytorch.py file.
Finally, I downloaded the pretrained models and their config JSON files, the same as you.
With the input text "a field of flowers", no error was reported:
sampling loop time step: 100%|███████████████████████████████████| 64/64 [00:02<00:00, 25.80it/s]
sampling loop time step: 100%|███████████████████████████████| 1000/1000 [03:37<00:00, 4.60it/s]
1it [03:41, 221.34s/it]
(1, 3, 64, 64)
I got the 64x64 image shown below:
I have no idea what the issue is now.
@ZhangxinruBIT
You didn't mention changing the keys in the decoder. This was something I mentioned and included in the code.
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
I only discovered this because I had strict=True in load_state_dict. Without the above modification, the script would error out because the keys in the .pth did not match the model defined in the code. If strict is False, no error is raised, but the mismatched weights are not loaded, so noisy results appear.
First, set strict=True for both the prior and the decoder. If a key mismatch occurs for the decoder and it reports missing keys referring to clip, paste the above code after the torch.load call and before load_state_dict.
@cest-andre thanks to you too for your reply! I get no error, but the image does not match the input text. Here are the results from the prompt "a red car":
Do you have any idea where the problem might come from?
Here is the code I used. The weight files are the ones you indicated before (for the prior, latest.pth and prior_config.json from the Hugging Face prior folder; for the decoder, latest.pth and decoder_config.json from the decoder/v1.0.2 folder).
import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create()
prior_model_state = torch.load("weights/prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create()
decoder_model_state = torch.load("weights/decoder_latest.pth", map_location=torch.device('cpu'))["model"]
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
decoder.load_state_dict(decoder_model_state, strict=True)
dalle2 = DALLE2(prior=prior, decoder=decoder)
images = dalle2(
['a red car'],
cond_scale = 2.
).cpu()
for img in images:
    img = ToPILImage()(img)
    img.show()
Here is the configuration of my conda environment:
# packages in environment at /Users/user/miniconda3/envs/clip:
#
# Name Version Build Channel
absl-py 1.4.0 pyhd8ed1ab_0 conda-forge
accelerate 0.17.1 pypi_0 pypi
aiohttp 3.8.3 py39h80987f9_0
aiosignal 1.3.1 pyhd8ed1ab_0 conda-forge
async-timeout 4.0.2 py39hca03da5_0
attrs 22.2.0 pyh71513ae_0 conda-forge
autopep8 1.6.0 pyhd3eb1b0_1
blas 1.0 openblas
blinker 1.5 pyhd8ed1ab_0 conda-forge
braceexpand 0.1.7 pypi_0 pypi
brotlipy 0.7.0 py39h1a28f6b_1002
bzip2 1.0.8 h620ffc9_4
c-ares 1.18.1 h3422bc3_0 conda-forge
ca-certificates 2023.01.10 hca03da5_0
cachetools 5.3.0 pyhd8ed1ab_0 conda-forge
certifi 2022.12.7 py39hca03da5_0
cffi 1.15.1 py39h80987f9_3
charset-normalizer 2.0.4 pyhd3eb1b0_0
click 8.1.3 unix_pyhd8ed1ab_2 conda-forge
clip 1.0 pypi_0 pypi
clip-anytorch 2.5.2 pypi_0 pypi
coca-pytorch 0.0.7 pypi_0 pypi
cryptography 39.0.1 py39h834c97f_0
dalle2-pytorch 1.1.0 pypi_0 pypi
einops 0.6.0 pypi_0 pypi
einops-exts 0.0.4 pypi_0 pypi
ema-pytorch 0.2.1 pypi_0 pypi
embedding-reader 1.5.0 pypi_0 pypi
ffmpeg 4.2.2 h04105a8_0
filelock 3.9.0 py39hca03da5_0
flit-core 3.6.0 pyhd3eb1b0_0
freetype 2.12.1 h1192e45_0
frozenlist 1.3.3 py39h80987f9_0
fsspec 2023.3.0 pypi_0 pypi
ftfy 6.1.1 pypi_0 pypi
gettext 0.21.0 h826f4ad_0
giflib 5.2.1 h80987f9_3
gmp 6.2.1 hc377ac9_3
gmpy2 2.1.2 py39h8c48613_0
gnutls 3.6.15 h887c41c_0
google-auth 2.16.2 pyh1a96a4e_0 conda-forge
google-auth-oauthlib 0.4.6 pyhd8ed1ab_0 conda-forge
grpcio 1.42.0 py39h95c9599_0
huggingface-hub 0.13.2 pypi_0 pypi
icu 68.1 hc377ac9_0
idna 3.4 py39hca03da5_0
importlib-metadata 6.1.0 pyha770c72_0 conda-forge
jinja2 3.1.2 py39hca03da5_0
jpeg 9e h80987f9_1
kornia 0.6.10 pypi_0 pypi
lame 3.100 h1a28f6b_0
lcms2 2.12 hba8e193_0
lerc 3.0 hc377ac9_0
libcxx 14.0.6 h848a8c0_0
libdeflate 1.17 h80987f9_0
libffi 3.4.2 hca03da5_6
libgfortran 5.0.0 11_3_0_hca03da5_28
libgfortran5 11.3.0 h009349e_28
libiconv 1.16 h1a28f6b_2
libidn2 2.3.1 h1a28f6b_0
libopenblas 0.3.21 h269037a_0
libopus 1.3 h1a28f6b_1
libpng 1.6.39 h80987f9_0
libprotobuf 3.20.3 h514c7bf_0
libtasn1 4.16.0 h1a28f6b_0
libtiff 4.5.0 h313beb8_2
libunistring 0.9.10 h1a28f6b_0
libvpx 1.10.0 hc377ac9_0
libwebp 1.2.4 ha3663a8_1
libwebp-base 1.2.4 h80987f9_1
libxml2 2.9.14 h8c5e841_0
llvm-openmp 14.0.6 hc6e5704_0
lpips 0.1.4 pypi_0 pypi
lz4-c 1.9.4 h313beb8_0
markdown 3.4.1 pyhd8ed1ab_0 conda-forge
markupsafe 2.1.1 py39h1a28f6b_0
mpc 1.1.0 h8c48613_1
mpfr 4.0.2 h695f6f0_1
mpmath 1.2.1 py39hca03da5_0
multidict 6.0.2 py39h1a28f6b_0
ncurses 6.4 h313beb8_0
nettle 3.7.3 h84b5d62_1
networkx 2.8.4 py39hca03da5_0
numpy 1.23.5 py39h1398885_0
numpy-base 1.23.5 py39h90707a3_0
oauthlib 3.2.2 pyhd8ed1ab_0 conda-forge
open-clip-torch 2.16.0 pypi_0 pypi
openh264 1.8.0 h98b2900_0
openssl 1.1.1t h1a28f6b_0
packaging 23.0 pypi_0 pypi
pandas 1.5.3 pypi_0 pypi
pillow 9.4.0 py39h313beb8_0
pip 23.0.1 py39hca03da5_0
protobuf 3.19.6 pypi_0 pypi
psutil 5.9.4 pypi_0 pypi
pyarrow 7.0.0 pypi_0 pypi
pyasn1 0.4.8 py_0 conda-forge
pyasn1-modules 0.2.7 py_0 conda-forge
pycodestyle 2.10.0 py39hca03da5_0
pycparser 2.21 pyhd3eb1b0_0
pydantic 1.10.6 pypi_0 pypi
pyjwt 2.6.0 pyhd8ed1ab_0 conda-forge
pyopenssl 23.0.0 py39hca03da5_0
pysocks 1.7.1 py39hca03da5_0
python 3.9.16 hc0d8a6c_2
python-dateutil 2.8.2 pypi_0 pypi
pytorch 2.0.0 py3.9_0 pytorch
pytorch-warmup 0.1.1 pypi_0 pypi
pytz 2022.7.1 pypi_0 pypi
pyu2f 0.1.5 pyhd8ed1ab_0 conda-forge
pyyaml 6.0 pypi_0 pypi
readline 8.2 h1a28f6b_0
regex 2022.10.31 pypi_0 pypi
requests 2.28.1 py39hca03da5_1
requests-oauthlib 1.3.1 pyhd8ed1ab_0 conda-forge
resize-right 0.0.2 pypi_0 pypi
rotary-embedding-torch 0.2.1 pypi_0 pypi
rsa 4.9 pyhd8ed1ab_0 conda-forge
scipy 1.10.1 pypi_0 pypi
sentencepiece 0.1.97 pypi_0 pypi
setuptools 65.6.3 py39hca03da5_0
six 1.16.0 pyh6c4a22f_0 conda-forge
sqlite 3.40.1 h7a7dc30_0
sympy 1.11.1 py39hca03da5_0
tensorboard 2.10.0 py39hca03da5_0
tensorboard-data-server 0.6.1 py39ha6e5c4f_0
tensorboard-plugin-wit 1.8.1 pyhd8ed1ab_0 conda-forge
timm 0.6.12 pypi_0 pypi
tk 8.6.12 hb8d0fd4_0
tokenizers 0.13.2 pypi_0 pypi
toml 0.10.2 pyhd3eb1b0_0
torch-fidelity 0.3.0 pypi_0 pypi
torch-tb-profiler 0.4.1 pypi_0 pypi
torchaudio 2.0.0 py39_cpu pytorch
torchmetrics 0.11.4 pypi_0 pypi
torchvision 0.15.0 py39_cpu pytorch
tqdm 4.65.0 pypi_0 pypi
transformers 4.27.0 pypi_0 pypi
typing_extensions 4.4.0 py39hca03da5_0
tzdata 2022g h04d1e81_0
urllib3 1.26.14 py39hca03da5_0
vector-quantize-pytorch 1.1.2 pypi_0 pypi
wcwidth 0.2.6 pypi_0 pypi
webdataset 0.2.43 pypi_0 pypi
werkzeug 2.2.3 pyhd8ed1ab_0 conda-forge
wheel 0.38.4 py39hca03da5_0
x-clip 0.12.1 pypi_0 pypi
x264 1!152.20180806 h1a28f6b_0
xz 5.2.10 h80987f9_1
yarl 1.8.1 py39h1a28f6b_0
zipp 3.15.0 pyhd8ed1ab_0 conda-forge
zlib 1.2.13 h5a0b063_0
zstd 1.5.2 h8574219_0
@tikitong I think you've got it working. The fact that you have a clean image that's at least car-related makes me think it's working properly. The imperfect results are more a function of the model's limitations than of any coding mistakes. I've also received some "off" results.
The model is non-deterministic, so you can run it multiple times and see if you get better images. But I think you're good to go.
@cest-andre thanks again for your time !
@cest-andre Hi, and many thanks for the reminder. When I used the same code as yours, I got an error:
import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
prior_config = TrainDiffusionPriorConfig.from_json_path(r"C:\Users\user\prior_config.json").prior
prior = prior_config.create()
prior_model_state = torch.load(r"C:\Users\user\prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)
decoder_config = TrainDecoderConfig.from_json_path(r"C:\Users\user\decoder_config.json").decoder
decoder = decoder_config.create()
decoder_model_state = torch.load(r"C:\Users\user\decoder_latest.pth", map_location=torch.device('cpu'))["model"]
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
decoder.load_state_dict(decoder_model_state, strict=True)
dalle2 = DALLE2(prior=prior, decoder=decoder)
images = dalle2(
['a red car'],
cond_scale = 2.
).cpu()
for img in images:
    img = ToPILImage()(img)
    img.show()
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 24
20 decoder.load_state_dict(decoder_model_state, strict=True)
22 dalle2 = DALLE2(prior=prior, decoder=decoder)
---> 24 images = dalle2(
25 ['a red car'],
26 cond_scale = 2.
27 ).cpu()
29 for img in images:
30 img = ToPILImage()(img)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
1047 # If we don't have any hooks, we want to skip the rest of the logic in
1048 # this function, and just call forward.
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:95, in eval_decorator.<locals>.inner(model, *args, **kwargs)
93 was_training = model.training
94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
96 model.train(was_training)
97 return out
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:2937, in DALLE2.forward(self, text, cond_scale, prior_cond_scale, return_pil_images)
2934 text = [text] if not isinstance(text, (list, tuple)) else text
2935 text = tokenizer.tokenize(text).to(device)
-> 2937 image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
2939 text_cond = text if self.decoder_need_text_cond else None
2940 images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:95, in eval_decorator.<locals>.inner(model, *args, **kwargs)
93 was_training = model.training
94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
96 model.train(was_training)
97 return out
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1212, in DiffusionPrior.sample(self, text, num_samples_per_batch, cond_scale, timesteps)
1209 if self.condition_on_text_encodings:
1210 text_cond = {**text_cond, 'text_encodings': text_encodings}
-> 1212 image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
1214 # retrieve original unscaled image embed
1216 image_embeds /= self.image_embed_scale
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1153, in DiffusionPrior.p_sample_loop(self, timesteps, *args, **kwargs)
1150 if not is_ddim:
1151 return self.p_sample_loop_ddpm(*args, **kwargs)
-> 1153 return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
25 @functools.wraps(func)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1116, in DiffusionPrior.p_sample_loop_ddim(self, shape, text_cond, timesteps, eta, cond_scale)
1112 alpha_next = alphas[time_next]
1114 time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-> 1116 pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
1118 if self.predict_x_start:
1119 x_start = pred
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:873, in DiffusionPriorNetwork.forward_with_cond_scale(self, cond_scale, *args, **kwargs)
867 def forward_with_cond_scale(
868 self,
869 *args,
870 cond_scale = 1.,
871 **kwargs
872 ):
--> 873 logits = self.forward(*args, **kwargs)
875 if cond_scale == 1:
876 return logits
File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:922, in DiffusionPriorNetwork.forward(self, image_embed, diffusion_timesteps, text_embed, text_encodings, cond_drop_prob)
918 mask = F.pad(mask, (0, remainder), value = False)
920 null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
--> 922 text_encodings = torch.where(
923 rearrange(mask, 'b n -> b n 1').clone(),
924 text_encodings,
925 null_text_embeds
926 )
928 # classifier free guidance
930 keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and cpu respectively
Can you help me to figure it out?
@kdavidlp123
You're not using exactly the same code: you're putting your prior and decoder on the CPU. I've only run everything on GPU, so I'm not sure what the exact problem is. Maybe you also need to move the DALLE2 object to the CPU.
@cest-andre Thank you for your reply. After I put everything on the CPU, it worked. I would also like to ask: how can I generate a larger picture? Does it depend on the model?
@kdavidlp123 Open decoder_config.json and change "image_sizes". I had trouble setting it to a higher value: the images looked oddly distorted. I don't know much about the inner workings of DALLE2, so my guess is that it cannot generalize to other resolutions without retraining, but I'm not sure.
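A sketch for editing that value without touching the downloaded file, assuming decoder_config.json has a top-level "decoder" object with an "image_sizes" list (the scripts above read `.decoder` from this file, but the exact JSON layout is an assumption). `with_image_sizes` and `rewrite_config` are hypothetical helpers, and as noted above, larger sizes may still give distorted images:

```python
import copy
import json


def with_image_sizes(config, sizes):
    """Return a copy of a decoder training config with decoder.image_sizes replaced.

    Assumes the layout {"decoder": {"image_sizes": [...], ...}, ...}.
    """
    new_config = copy.deepcopy(config)  # leave the caller's dict untouched
    decoder = new_config["decoder"]
    if "image_sizes" not in decoder:
        raise KeyError("decoder.image_sizes not found; the config layout differs from the assumption")
    if len(sizes) != len(decoder["image_sizes"]):
        # keep one entry per existing size (assumed to be one per unet)
        raise ValueError("number of sizes must stay the same")
    decoder["image_sizes"] = list(sizes)
    return new_config


def rewrite_config(src_path, dst_path, sizes):
    """Write a modified copy of the JSON config to dst_path."""
    with open(src_path, encoding="utf-8") as f:
        config = json.load(f)
    with open(dst_path, "w", encoding="utf-8") as f:
        json.dump(with_image_sizes(config, sizes), f, indent=2)
```

For a single-unet config, something like `rewrite_config("weights/decoder_config.json", "weights/decoder_config_128.json", [128])` writes a new file whose path can then be passed to `TrainDecoderConfig.from_json_path`.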
Hi @tikitong, sorry to bother you: where did you get "weights/prior_config.json" and "prior_latest.pth"?
I tried to download them from https://huggingface.co/laion/DALLE2-PyTorch/tree/main/prior
and https://huggingface.co/laion/DALLE2-PyTorch/tree/main/decoder/v1.0.2
However, these lines raise an error:
prior.load_state_dict(prior_model_state, strict=True)
decoder.load_state_dict(decoder_model_state, strict=True)
If I set strict=False, there is no error, but it only generates almost random images. For example, the error shows:
RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".
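Pairs like net.null_text_embed / net.null_text_embeds suggest that a parameter was renamed between library versions, which points to a mismatch between the checkpoint and the installed dalle2-pytorch. A small diagnostic sketch using difflib from the standard library; `suggest_renames` is a hypothetical helper and the 0.9 cutoff is an arbitrary choice. It only reports likely renames; the fix discussed in this thread is installing the matching version, not renaming keys:

```python
import difflib


def suggest_renames(missing, unexpected, cutoff=0.9):
    """Pair each unexpected checkpoint key with its closest missing model key.

    Only pairs whose similarity ratio is at least `cutoff` are reported.
    """
    pairs = {}
    for key in unexpected:
        match = difflib.get_close_matches(key, missing, n=1, cutoff=cutoff)
        if match:
            pairs[key] = match[0]
    return pairs


if __name__ == "__main__":
    # Keys copied from the error message above.
    missing = ["net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed"]
    unexpected = ["net.null_text_embed"]
    print(suggest_renames(missing, unexpected))
```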
(Quoting @cest-andre's earlier reply to @ZhangxinruBIT about copying the CLIP keys into the decoder state dict.)
This worked for me, thank you!
(Quoting the question above about the missing net.null_text_* keys in the prior.)
Hi, I solved this problem by deleting the folder named dalle2_pytorch, cuz I found that when it exists, the version of dalle2_pytorch I end up using is 1.14.0, not 1.1.0. Maybe you can try it too?
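A local folder named dalle2_pytorch (for example a clone of the repository) takes precedence over the installed package when it comes first on sys.path. A minimal sketch to check where the import would resolve, without importing the package; `import_location` and `is_inside` are hypothetical helpers. When running a script, the script's directory rather than the working directory is first on sys.path:

```python
import importlib.util
import os


def import_location(module_name):
    """Return the path `import module_name` would load from, without importing it."""
    spec = importlib.util.find_spec(module_name)
    if spec is None:
        return None
    if spec.origin not in (None, "namespace"):
        return spec.origin  # regular module or a package's __init__.py
    # namespace package (a folder without __init__.py): report its first directory
    locations = list(spec.submodule_search_locations or [])
    return locations[0] if locations else None


def is_inside(path, directory):
    """True if `path` lies within `directory` (after resolving symlinks)."""
    path = os.path.realpath(path)
    directory = os.path.realpath(directory)
    return os.path.commonpath([path, directory]) == directory


if __name__ == "__main__":
    location = import_location("dalle2_pytorch")
    print("dalle2_pytorch would be imported from:", location)
    if location and is_inside(location, os.getcwd()):
        print("warning: this is inside the working directory, not site-packages")
```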
(Quoting the answer above about deleting the local dalle2_pytorch folder.)
I got it running by following your answer. Many thanks!
@cest-andre can you please attach the updated code here for others to use as an initial reference? That would be a great help.
The bottom of my first comment above has the code.
https://github.com/lucidrains/DALLE2-pytorch/issues/282#issuecomment-1468429675
@cest-andre I'm getting the following error while running the mentioned code:
Traceback (most recent call last):
File "/home/ubuntu/dalle2/main.py", line 4, in [...]
[...] @root_validator with pre=False (the default) you MUST specify skip_on_failure=True. Note that @root_validator is deprecated and should be replaced with @model_validator.
For further information visit https://errors.pydantic.dev/2.0.3/u/root-validator-pre-skip
Solution: I installed pydantic==1.10.6 (the error above comes from pydantic 2.0.3).
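The error message comes from pydantic 2.x rejecting the pydantic 1 style @root_validator usage. A minimal sketch for checking which major version is installed before running the script (helper names are hypothetical); pinning pydantic below 2, as with 1.10.6 here, is the approach reported above:

```python
from importlib.metadata import PackageNotFoundError, version


def major_version(version_string):
    """Leading integer of a version string, e.g. '1.10.6' -> 1, '2.0.3' -> 2."""
    return int(version_string.split(".")[0])


def pydantic_is_v1():
    """True if pydantic 1.x is installed, False if 2.x or not installed."""
    try:
        return major_version(version("pydantic")) == 1
    except PackageNotFoundError:
        return False


if __name__ == "__main__":
    if not pydantic_is_v1():
        print('pydantic 1.x not found; try: pip install "pydantic<2"')
```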
(Quoting @cest-andre's first comment above, including the loading script.)
Hi, I got an error from the decoder. I suppose it's due to the dalle2-pytorch version; does this actually work with the current file in the repository?
decoder.load_state_dict(decoder_model_state, strict=True)
_IncompatibleKeys(missing_keys=['unets.0.ups.0.3.net.0.weight', 'unets.0.ups.0.3.net.0.bias', 'unets.0.ups.1.3.net.0.weight', 'unets.0.ups.1.3.net.0.bias', 'unets.0.ups.2.3.net.0.weight', 'unets.0.ups.2.3.net.0.bias', 'unets.0.ups.3.3.net.0.weight', 'unets.0.ups.3.3.net.0.bias'], unexpected_keys=['unets.0.ups.0.3.weight', 'unets.0.ups.0.3.bias', 'unets.0.ups.1.3.weight', 'unets.0.ups.1.3.bias', 'unets.0.ups.2.3.weight', 'unets.0.ups.2.3.bias', 'unets.0.ups.3.3.weight', 'unets.0.ups.3.3.bias'])
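Every pair in that output differs only by an extra .net.0 segment (for example unets.0.ups.0.3.weight versus unets.0.ups.0.3.net.0.weight), which suggests the installed dalle2-pytorch nests that layer differently from the code that produced the checkpoint, i.e. a version mismatch. A small diagnostic sketch that checks whether a single inserted segment explains all mismatches; `explain_by_inserted_segment` is a hypothetical helper:

```python
def explain_by_inserted_segment(missing, unexpected, segment=".net.0"):
    """Map unexpected keys to missing keys that differ only by `segment`.

    'unets.0.ups.0.3.weight' is checked against 'unets.0.ups.0.3.net.0.weight'.
    Returns (mapping, unexplained_missing).
    """
    missing_set = set(missing)
    mapping = {}
    for key in unexpected:
        module, _, param = key.rpartition(".")  # split off 'weight' / 'bias'
        candidate = f"{module}{segment}.{param}"
        if candidate in missing_set:
            mapping[key] = candidate
    unexplained = sorted(missing_set - set(mapping.values()))
    return mapping, unexplained
```

If `unexplained` comes back empty and every unexpected key is mapped, the checkpoint and the installed code differ only in how that layer is nested.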
This is my code for generating images, but the generated images are very blurry.
Prior model: https://huggingface.co/laion/DALLE2-PyTorch/tree/main/prior
Decoder model: https://huggingface.co/laion/DALLE2-PyTorch/tree/main/decoder/v1.0.2
import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()
prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)
decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()
decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]
decoder.load_state_dict(decoder_model_state, strict=True)
dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()
images = dalle2(text = ['A high quality photo of Times Square'], cond_scale = 2.)
for img in images:
    img.save("out.jpg")
What is the reason for this?
(Quoting the answer above about deleting the local dalle2_pytorch folder.)
Where is the cuz folder, bro?
(Quoting the exchange above about deleting the local dalle2_pytorch folder and the "cuz folder" question.)
I have the same problem. I changed the version to 1.1.0, but there are still some errors.
This is my code for generating images, but the generated images are random.
Prior model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/prior/best.pth
Decoder model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/decoder/1.5B/latest.pth