lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

can not generate normal image with pretrained model #282

Open LIUHAO121 opened 1 year ago

LIUHAO121 commented 1 year ago

This is my code for generating images, but the generated images look random. Prior model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/prior/best.pth Decoder model: https://huggingface.co/laion/DALLE2-PyTorch/blob/main/decoder/1.5B/latest.pth

import torch
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter
from dalle2_pytorch import Unet, Decoder,DALLE2

prior_network = DiffusionPriorNetwork(
    dim=768,
    depth=12,
    dim_head=64,
    heads=12,
    normformer=True,
    attn_dropout=5e-2,
    ff_dropout=5e-2,
    num_time_embeds=1,
    num_image_embeds=1,
    num_text_embeds=1,
    num_timesteps=1000,
    ff_mult=4,
    final_proj= True,
    rotary_emb= True
)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=OpenAIClipAdapter("ViT-L/14"),
    image_embed_dim=768,
    timesteps=1000,
    # sample_timesteps = 64,
    cond_drop_prob=0.1,
    loss_type="l2",
    condition_on_text_encodings=True,

)

diffusion_prior.load_state_dict(torch.load("prior.pth",map_location=torch.device('cpu')),strict=False)

unet = Unet(
    **{"dim": 320,
    "cond_dim": 512,
    "image_embed_dim": 768,
    "text_embed_dim": 768,
    "cond_on_text_encodings": True,
    "channels": 3,
    "dim_mults": [1, 2, 3, 4],
    "num_resnet_blocks": 4,
    "attn_heads": 8,
    "attn_dim_head": 64,
    "sparse_attn": True,
    "memory_efficient": True,
    "self_attn": [False, True, True, True]}
)

decoder = Decoder(
    unet = unet,
    clip=OpenAIClipAdapter("ViT-L/14"),
    timesteps = 1000,
    image_sizes = [64],
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5,
    learned_variance=True
)
decoder.load_state_dict(torch.load("decoder.pth",map_location=torch.device('cpu')),strict=False)

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale = 2., # classifier free guidance strength (> 1 would strengthen the condition)
    return_pil_images=True,
)
for img in images:
    img.save("out.jpg")
[Screenshot 2023-03-08 18 53 58]
cest-andre commented 1 year ago

I initially also had trouble getting pretrained weights working properly with this repository but I resolved the problem. Not sure if you have the same problem, but I'll relay in case it helps.

First, it was helpful to set strict=True so that I could see the discrepancies between the weights and the models. The key to fixing my problem was noticing that the model defined parameters which did not exist in the pth file. Since the latest pth file was ~8 months old, I downgraded dalle2-pytorch to version 1.1.0.
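A minimal sketch of that diagnostic in plain PyTorch: nn.Module.load_state_dict returns a named tuple with missing_keys and unexpected_keys, so you can list the mismatch directly. The toy classes Old and New are hypothetical stand-ins for "the model the checkpoint was saved from" and "a newer library version that added a parameter"; they are not the real DALLE2 classes.

```python
import torch
import torch.nn as nn

def report_key_mismatch(model: nn.Module, state_dict: dict):
    """Load state_dict non-strictly and return (missing_keys, unexpected_keys).

    missing_keys: tensors the model defines but the checkpoint lacks.
    unexpected_keys: checkpoint entries the model does not define.
    Note: matching tensors ARE loaded into the model as a side effect.
    """
    result = model.load_state_dict(state_dict, strict=False)
    return list(result.missing_keys), list(result.unexpected_keys)

# Hypothetical toy stand-ins for an older checkpoint and a newer model definition.
class Old(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class New(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.null_embed = nn.Parameter(torch.zeros(4))  # added in the "newer version"

missing, unexpected = report_key_mismatch(New(), Old().state_dict())
print("missing:", missing)        # ['null_embed']
print("unexpected:", unexpected)  # []
```

For the real models you would call, for example, report_key_mismatch(decoder, decoder_model_state). In this thread, missing keys starting with clip. pointed at the CLIP weights, while names like net.null_text_embeds turned out to indicate a library version mismatch.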

Now the prior weights worked, but the decoder pth was still missing CLIP-related parameters. I then took the state_dict from the pretrained CLIP model and inserted it into the decoder's state_dict loaded from the pth file.
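The "insert CLIP into the decoder state dict" step can be written as a small helper. This is a sketch assuming, as in the script in this comment, that the decoder stores CLIP under the clip. prefix; the helper name add_prefixed_weights is hypothetical. Unlike a plain loop that assigns every key, it only fills keys the checkpoint lacks, so any weights the checkpoint does contain are never overwritten.

```python
def add_prefixed_weights(target: dict, source: dict, prefix: str) -> list:
    """Copy source[k] into target[prefix + k] for every key that target lacks.

    Returns the list of keys that were added, so the caller can log them.
    """
    added = []
    for key, value in source.items():
        full_key = prefix + key
        if full_key not in target:
            target[full_key] = value
            added.append(full_key)
    return added

# Toy example with plain values standing in for tensors.
checkpoint = {"unets.0.weight": 1, "clip.a": 0}
added = add_prefixed_weights(checkpoint, {"a": 5, "b": 6}, "clip.")
print(added)       # ['clip.b']
print(checkpoint)  # {'unets.0.weight': 1, 'clip.a': 0, 'clip.b': 6}
```

With the real objects, the call would be add_prefixed_weights(decoder_model_state, decoder.clip.state_dict(), "clip.") between torch.load and decoder.load_state_dict(..., strict=True).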

Finally, I had to modify the dalle2_pytorch.py file, as there seemed to be a bug in the DALLE2 class's forward method. On line 2940, image_embed is passed to decoder.sample positionally instead of as the image_embed keyword argument, so it is not bound to that parameter and the call produces an error. This bug has been fixed in more recent versions of the repo, but for 1.1.0 you need to change the line to:

images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

Here is my script to get things running:

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

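decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

# The decoder checkpoint does not include the CLIP weights, so copy them from
# the decoder's own pretrained CLIP module under the "clip." prefix.
clip_state = decoder.clip.state_dict()
for k in clip_state.keys():
    decoder_model_state["clip." + k] = clip_state[k]

decoder.load_state_dict(decoder_model_state, strict=True)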

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.
).cpu()

print(images.shape)

for img in images:
    img = ToPILImage()(img)
    img.show()

Let me know if you have any questions. Hope this helps.

klei22 commented 1 year ago

@cest-andre Hi! I am seeing the same problem as LIUHAO121 even with the above, and I'm wondering if I might be using the wrong set of files (prior_config.json, decoder_config.json, decoder_latest.pth, and prior_latest.pth).

I'm currently downloading these files from the Hugging Face repository. Is this correct?

Also, do the JSON files need to be modified after downloading (e.g., to fix paths)?

cest-andre commented 1 year ago

@klei22 Apologies, I did rename the pth files so that may have added confusion.

For prior, I am using latest.pth and prior_config.json in huggingface's prior folder. For decoder, I am using latest.pth and decoder_config.json in decoder/v1.0.2 folder.

Note that in the code I posted, I modify the keys in decoder's state dictionary so that it recognizes the CLIP weights.

tikitong commented 1 year ago

@cest-andre Thank you very much for sharing. It doesn't seem to work on my end either. Can you share an example of the input text and the output image you get?

cest-andre commented 1 year ago

@tikitong Sure thing. A reminder that this fix requires downgrading dalle2-pytorch to version 1.1.0. Could you be more specific about what doesn't work on your end? Do you get an error or do you get bad image results?

Here are the results from the prompt 'a field of flowers':

[Image: field_of_flowers]

ZhangxinruBIT commented 1 year ago

@cest-andre @tikitong Hi, and many thanks for the reminder.

I ran into the same problem. I first pinned dalle2-pytorch to version 1.1.0 and double-checked it: >>> dalle2_pytorch.__version__ returns '1.1.0'.

Then I modified the code on line 2940 of the dalle2_pytorch.py file.

Finally, I downloaded the pretrained models and their config JSON files, the same as you.

With the input text "a field of flowers", no error was reported:

sampling loop time step: 100%|███████████████████████████████████| 64/64 [00:02<00:00, 25.80it/s]
sampling loop time step: 100%|███████████████████████████████| 1000/1000 [03:37<00:00, 4.60it/s]
1it [03:41, 221.34s/it]: 100%|███████████████████████████████| 1000/1000 [03:37<00:00, 4.70it/s]
(1, 3, 64, 64)

The result is the 64x64 image shown below:

[Image]

I'm out of ideas about what the issue is now.

cest-andre commented 1 year ago

@ZhangxinruBIT

You didn't mention changing the keys in the decoder. This was something I mentioned and included in the code.

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

I only discovered this because I had strict=True in load_state_dict. Without the above modification, the script would error out because the keys in the pth did not match the model defined in the code. If strict is False (PyTorch's default is actually strict=True; the original script set it to False explicitly), no error occurs, but the mismatched weights are silently skipped, so noisy results appear.

First, set strict=True for both the prior and the decoder. If a key mismatch occurs for the decoder and the missing keys refer to clip, paste the above code after loading the pth with torch.load and before calling load_state_dict.
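Why strict=False gives noise instead of an error can be shown with a toy model in plain PyTorch (the Net class is hypothetical, not part of dalle2-pytorch): a parameter that is absent from the checkpoint simply keeps its random initialization, while strict=True raises a RuntimeError naming the missing keys.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical toy model; the "checkpoint" below only covers layer a.
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(2, 2)
        self.b = nn.Linear(2, 2)

torch.manual_seed(0)
model = Net()
b_before = model.b.weight.detach().clone()

# A checkpoint containing only layer a's tensors (all ones, so loading is visible).
partial = {k: torch.ones_like(v) for k, v in model.state_dict().items() if k.startswith("a.")}

# strict=False: no error; a is loaded, b silently keeps its random initialization.
model.load_state_dict(partial, strict=False)
a_loaded = torch.equal(model.a.weight, torch.ones(2, 2))
b_untouched = torch.equal(model.b.weight, b_before)

# strict=True (PyTorch's default): the same checkpoint raises and names the missing keys.
try:
    model.load_state_dict(partial, strict=True)
    strict_error = None
except RuntimeError as err:
    strict_error = str(err)

print(a_loaded, b_untouched)  # True True
print(strict_error)
```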

tikitong commented 1 year ago

@cest-andre Thanks for your reply too! I get no error, but the image does not match the input text. Here is the result for the prompt "a red car": [Image: img2]

Do you have any idea where the problem might come from?

Here is the code I used. The weight files are the ones you indicated before (for the prior, latest.pth and prior_config.json in Hugging Face's prior folder; for the decoder, latest.pth and decoder_config.json in the decoder/v1.0.2 folder).

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create()

prior_model_state = torch.load("weights/prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create()

decoder_model_state = torch.load("weights/decoder_latest.pth", map_location=torch.device('cpu'))["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder)

images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()

Here is the configuration of my conda environment:

# packages in environment at /Users/user/miniconda3/envs/clip:
#
# Name                    Version                   Build  Channel
absl-py                   1.4.0              pyhd8ed1ab_0    conda-forge
accelerate                0.17.1                   pypi_0    pypi
aiohttp                   3.8.3            py39h80987f9_0  
aiosignal                 1.3.1              pyhd8ed1ab_0    conda-forge
async-timeout             4.0.2            py39hca03da5_0  
attrs                     22.2.0             pyh71513ae_0    conda-forge
autopep8                  1.6.0              pyhd3eb1b0_1  
blas                      1.0                    openblas  
blinker                   1.5                pyhd8ed1ab_0    conda-forge
braceexpand               0.1.7                    pypi_0    pypi
brotlipy                  0.7.0           py39h1a28f6b_1002  
bzip2                     1.0.8                h620ffc9_4  
c-ares                    1.18.1               h3422bc3_0    conda-forge
ca-certificates           2023.01.10           hca03da5_0  
cachetools                5.3.0              pyhd8ed1ab_0    conda-forge
certifi                   2022.12.7        py39hca03da5_0  
cffi                      1.15.1           py39h80987f9_3  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
click                     8.1.3           unix_pyhd8ed1ab_2    conda-forge
clip                      1.0                      pypi_0    pypi
clip-anytorch             2.5.2                    pypi_0    pypi
coca-pytorch              0.0.7                    pypi_0    pypi
cryptography              39.0.1           py39h834c97f_0  
dalle2-pytorch            1.1.0                    pypi_0    pypi
einops                    0.6.0                    pypi_0    pypi
einops-exts               0.0.4                    pypi_0    pypi
ema-pytorch               0.2.1                    pypi_0    pypi
embedding-reader          1.5.0                    pypi_0    pypi
ffmpeg                    4.2.2                h04105a8_0  
filelock                  3.9.0            py39hca03da5_0  
flit-core                 3.6.0              pyhd3eb1b0_0  
freetype                  2.12.1               h1192e45_0  
frozenlist                1.3.3            py39h80987f9_0  
fsspec                    2023.3.0                 pypi_0    pypi
ftfy                      6.1.1                    pypi_0    pypi
gettext                   0.21.0               h826f4ad_0  
giflib                    5.2.1                h80987f9_3  
gmp                       6.2.1                hc377ac9_3  
gmpy2                     2.1.2            py39h8c48613_0  
gnutls                    3.6.15               h887c41c_0  
google-auth               2.16.2             pyh1a96a4e_0    conda-forge
google-auth-oauthlib      0.4.6              pyhd8ed1ab_0    conda-forge
grpcio                    1.42.0           py39h95c9599_0  
huggingface-hub           0.13.2                   pypi_0    pypi
icu                       68.1                 hc377ac9_0  
idna                      3.4              py39hca03da5_0  
importlib-metadata        6.1.0              pyha770c72_0    conda-forge
jinja2                    3.1.2            py39hca03da5_0  
jpeg                      9e                   h80987f9_1  
kornia                    0.6.10                   pypi_0    pypi
lame                      3.100                h1a28f6b_0  
lcms2                     2.12                 hba8e193_0  
lerc                      3.0                  hc377ac9_0  
libcxx                    14.0.6               h848a8c0_0  
libdeflate                1.17                 h80987f9_0  
libffi                    3.4.2                hca03da5_6  
libgfortran               5.0.0           11_3_0_hca03da5_28  
libgfortran5              11.3.0              h009349e_28  
libiconv                  1.16                 h1a28f6b_2  
libidn2                   2.3.1                h1a28f6b_0  
libopenblas               0.3.21               h269037a_0  
libopus                   1.3                  h1a28f6b_1  
libpng                    1.6.39               h80987f9_0  
libprotobuf               3.20.3               h514c7bf_0  
libtasn1                  4.16.0               h1a28f6b_0  
libtiff                   4.5.0                h313beb8_2  
libunistring              0.9.10               h1a28f6b_0  
libvpx                    1.10.0               hc377ac9_0  
libwebp                   1.2.4                ha3663a8_1  
libwebp-base              1.2.4                h80987f9_1  
libxml2                   2.9.14               h8c5e841_0  
llvm-openmp               14.0.6               hc6e5704_0  
lpips                     0.1.4                    pypi_0    pypi
lz4-c                     1.9.4                h313beb8_0  
markdown                  3.4.1              pyhd8ed1ab_0    conda-forge
markupsafe                2.1.1            py39h1a28f6b_0  
mpc                       1.1.0                h8c48613_1  
mpfr                      4.0.2                h695f6f0_1  
mpmath                    1.2.1            py39hca03da5_0  
multidict                 6.0.2            py39h1a28f6b_0  
ncurses                   6.4                  h313beb8_0  
nettle                    3.7.3                h84b5d62_1  
networkx                  2.8.4            py39hca03da5_0  
numpy                     1.23.5           py39h1398885_0  
numpy-base                1.23.5           py39h90707a3_0  
oauthlib                  3.2.2              pyhd8ed1ab_0    conda-forge
open-clip-torch           2.16.0                   pypi_0    pypi
openh264                  1.8.0                h98b2900_0  
openssl                   1.1.1t               h1a28f6b_0  
packaging                 23.0                     pypi_0    pypi
pandas                    1.5.3                    pypi_0    pypi
pillow                    9.4.0            py39h313beb8_0  
pip                       23.0.1           py39hca03da5_0  
protobuf                  3.19.6                   pypi_0    pypi
psutil                    5.9.4                    pypi_0    pypi
pyarrow                   7.0.0                    pypi_0    pypi
pyasn1                    0.4.8                      py_0    conda-forge
pyasn1-modules            0.2.7                      py_0    conda-forge
pycodestyle               2.10.0           py39hca03da5_0  
pycparser                 2.21               pyhd3eb1b0_0  
pydantic                  1.10.6                   pypi_0    pypi
pyjwt                     2.6.0              pyhd8ed1ab_0    conda-forge
pyopenssl                 23.0.0           py39hca03da5_0  
pysocks                   1.7.1            py39hca03da5_0  
python                    3.9.16               hc0d8a6c_2  
python-dateutil           2.8.2                    pypi_0    pypi
pytorch                   2.0.0                   py3.9_0    pytorch
pytorch-warmup            0.1.1                    pypi_0    pypi
pytz                      2022.7.1                 pypi_0    pypi
pyu2f                     0.1.5              pyhd8ed1ab_0    conda-forge
pyyaml                    6.0                      pypi_0    pypi
readline                  8.2                  h1a28f6b_0  
regex                     2022.10.31               pypi_0    pypi
requests                  2.28.1           py39hca03da5_1  
requests-oauthlib         1.3.1              pyhd8ed1ab_0    conda-forge
resize-right              0.0.2                    pypi_0    pypi
rotary-embedding-torch    0.2.1                    pypi_0    pypi
rsa                       4.9                pyhd8ed1ab_0    conda-forge
scipy                     1.10.1                   pypi_0    pypi
sentencepiece             0.1.97                   pypi_0    pypi
setuptools                65.6.3           py39hca03da5_0  
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.40.1               h7a7dc30_0  
sympy                     1.11.1           py39hca03da5_0  
tensorboard               2.10.0           py39hca03da5_0  
tensorboard-data-server   0.6.1            py39ha6e5c4f_0  
tensorboard-plugin-wit    1.8.1              pyhd8ed1ab_0    conda-forge
timm                      0.6.12                   pypi_0    pypi
tk                        8.6.12               hb8d0fd4_0  
tokenizers                0.13.2                   pypi_0    pypi
toml                      0.10.2             pyhd3eb1b0_0  
torch-fidelity            0.3.0                    pypi_0    pypi
torch-tb-profiler         0.4.1                    pypi_0    pypi
torchaudio                2.0.0                  py39_cpu    pytorch
torchmetrics              0.11.4                   pypi_0    pypi
torchvision               0.15.0                 py39_cpu    pytorch
tqdm                      4.65.0                   pypi_0    pypi
transformers              4.27.0                   pypi_0    pypi
typing_extensions         4.4.0            py39hca03da5_0  
tzdata                    2022g                h04d1e81_0  
urllib3                   1.26.14          py39hca03da5_0  
vector-quantize-pytorch   1.1.2                    pypi_0    pypi
wcwidth                   0.2.6                    pypi_0    pypi
webdataset                0.2.43                   pypi_0    pypi
werkzeug                  2.2.3              pyhd8ed1ab_0    conda-forge
wheel                     0.38.4           py39hca03da5_0  
x-clip                    0.12.1                   pypi_0    pypi
x264                      1!152.20180806       h1a28f6b_0  
xz                        5.2.10               h80987f9_1  
yarl                      1.8.1            py39h1a28f6b_0  
zipp                      3.15.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               h5a0b063_0  
zstd                      1.5.2                h8574219_0 
cest-andre commented 1 year ago

@tikitong I think you've got it working. The fact that you have a clean image that's at least car-related makes me think it's working properly. The imperfect results are more a function of the model's limitations than of any coding mistakes. I've also gotten some "off" results.

The model is non-deterministic, so you can run it multiple times and see if you get better images. But I think you're good to go.
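Since each run starts from random noise, you can also make a run repeatable or try several seeds deliberately. A minimal sketch, assuming the sampler draws its noise from PyTorch's global RNG (as torch.randn-based diffusion samplers typically do); sample_noise is a hypothetical stand-in for the real dalle2(...) call, and run_with_seed is a hypothetical helper.

```python
import torch

def sample_noise(shape=(1, 3, 64, 64)):
    # Hypothetical stand-in for a diffusion sampler's starting noise.
    return torch.randn(shape)

def run_with_seed(seed, fn, *args, **kwargs):
    """Seed PyTorch's global RNG, then call fn, so the result is repeatable."""
    torch.manual_seed(seed)
    return fn(*args, **kwargs)

first = run_with_seed(42, sample_noise)
again = run_with_seed(42, sample_noise)
other = run_with_seed(7, sample_noise)
print(torch.equal(first, again), torch.equal(first, other))  # True False
```

With the real model this would look like run_with_seed(seed, dalle2, ['a red car'], cond_scale = 2.) inside a loop over a few seeds. On GPU, some kernels are nondeterministic, so a fixed seed may not reproduce results bit for bit.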

tikitong commented 1 year ago

@cest-andre thanks again for your time !

kdavidlp123 commented 1 year ago

@cest-andre Hi, and many thanks for the reminder. When I used the same code as yours, I got the following error:

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path(r"C:\Users\user\prior_config.json").prior
prior = prior_config.create()

prior_model_state = torch.load(r"C:\Users\user\prior_latest.pth", map_location=torch.device('cpu'))
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path(r"C:\Users\user\decoder_config.json").decoder
decoder = decoder_config.create()

decoder_model_state = torch.load(r"C:\Users\user\decoder_latest.pth", map_location=torch.device('cpu'))["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder)

images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

for img in images:
    img = ToPILImage()(img)
    img.show()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[1], line 24
     20 decoder.load_state_dict(decoder_model_state, strict=True)
     22 dalle2 = DALLE2(prior=prior, decoder=decoder)
---> 24 images = dalle2(
     25     ['a red car'],
     26     cond_scale = 2.
     27 ).cpu()
     29 for img in images:
     30     img = ToPILImage()(img)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\nn\modules\module.py:1051, in Module._call_impl(self, *input, **kwargs)
   1047 # If we don't have any hooks, we want to skip the rest of the logic in
   1048 # this function, and just call forward.
   1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051     return forward_call(*input, **kwargs)
   1052 # Do not call functions when jit is used
   1053 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:95, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     93 was_training = model.training
     94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
     96 model.train(was_training)
     97 return out

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:2937, in DALLE2.forward(self, text, cond_scale, prior_cond_scale, return_pil_images)
   2934     text = [text] if not isinstance(text, (list, tuple)) else text
   2935     text = tokenizer.tokenize(text).to(device)
-> 2937 image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
   2939 text_cond = text if self.decoder_need_text_cond else None
   2940 images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:95, in eval_decorator.<locals>.inner(model, *args, **kwargs)
     93 was_training = model.training
     94 model.eval()
---> 95 out = fn(model, *args, **kwargs)
     96 model.train(was_training)
     97 return out

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1212, in DiffusionPrior.sample(self, text, num_samples_per_batch, cond_scale, timesteps)
   1209 if self.condition_on_text_encodings:
   1210     text_cond = {**text_cond, 'text_encodings': text_encodings}
-> 1212 image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)
   1214 # retrieve original unscaled image embed
   1216 image_embeds /= self.image_embed_scale

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1153, in DiffusionPrior.p_sample_loop(self, timesteps, *args, **kwargs)
   1150 if not is_ddim:
   1151     return self.p_sample_loop_ddpm(*args, **kwargs)
-> 1153 return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\torch\autograd\grad_mode.py:28, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_context(*args, **kwargs):
     27     with self.__class__():
---> 28         return func(*args, **kwargs)

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:1116, in DiffusionPrior.p_sample_loop_ddim(self, shape, text_cond, timesteps, eta, cond_scale)
   1112 alpha_next = alphas[time_next]
   1114 time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-> 1116 pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
   1118 if self.predict_x_start:
   1119     x_start = pred

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:873, in DiffusionPriorNetwork.forward_with_cond_scale(self, cond_scale, *args, **kwargs)
    867 def forward_with_cond_scale(
    868     self,
    869     *args,
    870     cond_scale = 1.,
    871     **kwargs
    872 ):
--> 873     logits = self.forward(*args, **kwargs)
    875     if cond_scale == 1:
    876         return logits

File ~\anaconda3\envs\dalle2_1.1.0\lib\site-packages\dalle2_pytorch\dalle2_pytorch.py:922, in DiffusionPriorNetwork.forward(self, image_embed, diffusion_timesteps, text_embed, text_encodings, cond_drop_prob)
    918     mask = F.pad(mask, (0, remainder), value = False)
    920 null_text_embeds = self.null_text_embed.to(text_encodings.dtype)
--> 922 text_encodings = torch.where(
    923     rearrange(mask, 'b n -> b n 1').clone(),
    924     text_encodings,
    925     null_text_embeds
    926 )
    928 # classifier free guidance
    930 keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)

RuntimeError: Expected condition, x and y to be on the same device, but condition is on cuda:0 and x and y are on cuda:0 and cpu respectively

Can you help me to figure it out?

cest-andre commented 1 year ago

@kdavidlp123

You're not using exactly the same code: you're loading your prior and decoder onto the CPU. I've only run everything on the GPU, so I'm not sure what the exact problem is. Maybe you also need to move the DALLE2 object to the CPU.
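A minimal sketch of making the device consistent in plain PyTorch: pick one device, move every module to it, and check that no registered parameter or buffer was left behind. In the traceback above, text_encodings was on cuda:0 while null_text_embeds was on the CPU. The helper names devices_of and move_all are hypothetical, and the toy layers stand in for prior, decoder and the DALLE2 wrapper.

```python
import torch
import torch.nn as nn

def devices_of(*modules: nn.Module) -> set:
    """Every device used by the registered parameters and buffers of the modules."""
    found = set()
    for m in modules:
        for t in list(m.parameters()) + list(m.buffers()):
            found.add(t.device)
    return found

def move_all(device: torch.device, *modules: nn.Module) -> set:
    """Move each module in place and return the devices in use afterwards."""
    for m in modules:
        m.to(device)  # moves registered parameters and buffers
    return devices_of(*modules)

# Toy stand-ins; BatchNorm1d also has buffers (running_mean, running_var, ...).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
layer, norm = nn.Linear(3, 3), nn.BatchNorm1d(3)
used = move_all(device, layer, norm)
print(used)  # a single device
```

This check only sees registered parameters and buffers. A tensor kept as a plain attribute, or created on some library's default device at call time, is not covered, so a device-mismatch RuntimeError can still appear and is the signal to look further.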

kdavidlp123 commented 1 year ago

@cest-andre Thank you for your reply. After I put everything on the CPU, it worked. I'd also like to ask: how can I generate a larger picture? Does that depend on the model?

cest-andre commented 1 year ago

@kdavidlp123 Open the decoder_config JSON and change "image_sizes". I had trouble setting it to a higher value: the images looked very oddly distorted. I don't know much about the inner workings of DALL-E, so I suppose it cannot generalize to other resolutions without retraining, but I'm not sure.
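To experiment with this without hand-editing the file, a sketch using the standard json module: load the config, replace the decoder's image_sizes, and write a copy. The nesting config["decoder"]["image_sizes"] is an assumption based on TrainDecoderConfig.from_json_path(...).decoder in the scripts above; check it against your downloaded decoder_config.json. The helper set_image_sizes is hypothetical.

```python
import json
from pathlib import Path

def set_image_sizes(src: str, dst: str, sizes: list) -> dict:
    """Copy a decoder config JSON, replacing decoder.image_sizes (assumed nesting)."""
    config = json.loads(Path(src).read_text())
    if "decoder" not in config or "image_sizes" not in config["decoder"]:
        raise KeyError("expected config['decoder']['image_sizes']; check the file layout")
    config["decoder"]["image_sizes"] = list(sizes)
    Path(dst).write_text(json.dumps(config, indent=2))
    return config
```

Writing to a new file keeps the original config intact. As noted above, a decoder trained at 64x64 may produce distorted images at other sizes without retraining.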

4daJKong commented 1 year ago

Hi @tikitong, sorry to bother you, but where did you get "weights/prior_config.json" and "prior_latest.pth"? I tried to download them from https://huggingface.co/laion/DALLE2-PyTorch/tree/main/prior and https://huggingface.co/laion/DALLE2-PyTorch/tree/main/decoder/v1.0.2, but I get errors at:

prior.load_state_dict(prior_model_state, strict=True)
decoder.load_state_dict(decoder_model_state, strict=True)

If I set strict=False, there is no error, but it only generates almost random images. With strict=True, the error is, for example:

RuntimeError: Error(s) in loading state_dict for DiffusionPrior: Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". Unexpected key(s) in state_dict: "net.null_text_embed".

cezeilo commented 1 year ago

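@cest-andre Adding the CLIP keys to the decoder's state dict and loading with strict=True, as you described in your reply to @ZhangxinruBIT, worked for me. Thank you!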

YANDaoyu commented 1 year ago

Hi @4daJKong, I solved this problem by deleting a folder named dalle2_pytorch: I found that while it existed, the version of dalle2_pytorch being used was 1.14.0, not 1.1.0. Maybe you can try that too?
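A quick way to check for this kind of shadowing using only the standard library: importlib.util.find_spec shows which file Python would import for a module name, and importlib.metadata.version shows the installed distribution's version. If the origin points into your working directory rather than site-packages, a local folder is shadowing the installed package. The helper where_is is hypothetical.

```python
import importlib.metadata
import importlib.util

def where_is(module_name: str, dist_name=None) -> dict:
    """Report where module_name would be imported from and the installed dist version."""
    spec = importlib.util.find_spec(module_name)
    origin = spec.origin if spec is not None else None
    try:
        version = importlib.metadata.version(dist_name or module_name)
    except importlib.metadata.PackageNotFoundError:
        version = None
    return {"origin": origin, "installed_version": version}

print(where_is("dalle2_pytorch", "dalle2-pytorch"))
```

A mismatch between the path in origin and the site-packages directory of the installed version is the symptom described above.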

fido20160817 commented 1 year ago

@YANDaoyu I got it running by following your answer. Many thanks!

JasbirCodeSpace commented 1 year ago

@cest-andre can you please attach the updated code here for others to use as an initial reference, that will be a great help.

cest-andre commented 1 year ago

The bottom of my first comment above has the code.

https://github.com/lucidrains/DALLE2-pytorch/issues/282#issuecomment-1468429675

JasbirCodeSpace commented 1 year ago

@cest-andre I'm getting the following error while running the mentioned code:

Traceback (most recent call last):
  File "/home/ubuntu/dalle2/main.py", line 4, in <module>
    from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/dalle2_pytorch/train_configs.py", line 34, in <module>
    class TrainSplitConfig(BaseModel):
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/dalle2_pytorch/train_configs.py", line 40, in TrainSplitConfig
    def validate_all(cls, fields):
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/pydantic/deprecated/class_validators.py", line 222, in root_validator
    return root_validator()(*__args)  # type: ignore
  File "/home/ubuntu/dalle2/venv/lib/python3.10/site-packages/pydantic/deprecated/class_validators.py", line 228, in root_validator
    raise PydanticUserError(
pydantic.errors.PydanticUserError: If you use @root_validator with pre=False (the default) you MUST specify skip_on_failure=True. Note that @root_validator is deprecated and should be replaced with @model_validator.

For further information visit https://errors.pydantic.dev/2.0.3/u/root-validator-pre-skip

Solution: I installed pydantic==1.10.6 (a 1.x release) and the import works.
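Based on the traceback above, dalle2_pytorch.train_configs in this version does not import under pydantic 2.0.3. A small guard run before that import can fail early with a clearer message. This is a hypothetical helper, not part of dalle2_pytorch, and it assumes a 1.x pydantic is what the package needs here, as the solution above suggests.

```python
from importlib import metadata


def major_version(version: str) -> int:
    """Return the major component of a version string such as '1.10.6'."""
    return int(version.split(".")[0])


def require_pydantic_v1() -> str:
    """Raise a readable error unless pydantic 1.x is installed.

    Assumption: dalle2_pytorch.train_configs needs pydantic 1.x, based on
    the PydanticUserError it raises under pydantic 2.0.3.
    """
    try:
        installed = metadata.version("pydantic")
    except metadata.PackageNotFoundError:
        raise RuntimeError(
            "pydantic is not installed; try: pip install pydantic==1.10.6"
        ) from None
    if major_version(installed) != 1:
        raise RuntimeError(
            f"pydantic {installed} found; dalle2_pytorch.train_configs needs "
            "pydantic 1.x here, e.g. pip install pydantic==1.10.6"
        )
    return installed
```

Call `require_pydantic_v1()` at the top of the script, before `from dalle2_pytorch.train_configs import ...`.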

esteban-rs commented 1 year ago

I initially also had trouble getting the pretrained weights working properly with this repository, but I resolved the problem. Not sure if you have the same problem, but I'll share what worked in case it helps.

First, it was helpful to set strict=True so that I could see the discrepancies between the weights and the models. The key to fixing my problem was noticing that there were parameters defined in the model which did not exist in the .pth file. Seeing that the latest .pth file was ~8 months old, I downgraded dalle2-pytorch to version 1.1.0.
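To list those discrepancies without the load raising an exception, a set difference over the key names is enough. A minimal sketch; `diff_state_dict_keys` is a hypothetical helper. In PyTorch, `load_state_dict(..., strict=False)` also returns an object with `missing_keys` and `unexpected_keys` fields carrying the same information.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, in the same sense as
    torch.nn.Module.load_state_dict(strict=True) reports them."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)     # defined in the model, absent from the .pth
    unexpected = sorted(ckpt_set - model_set)  # present in the .pth, unknown to the model
    return missing, unexpected


# Example (assumes torch and a `prior` built from prior_config.json):
#   missing, unexpected = diff_state_dict_keys(
#       prior.state_dict().keys(), torch.load("weights/prior_latest.pth").keys())
```

A non-empty `missing` list for keys like `net.null_text_embeds` is the kind of signal that the installed code is newer than the checkpoint.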

Now the prior weights loaded, but the decoder .pth was still missing the parameters related to CLIP. I then took the state_dict of the pretrained CLIP and copied it into the decoder's state_dict loaded from the .pth file.

Finally, I had to modify the dalle2_pytorch.py file, as there seemed to be a bug in the DALLE2 class's forward method. On line 2940, image_embed is not passed to decoder.sample as its image_embed keyword argument, which raises an error when DALLE2 is called. This bug has been fixed in more recent versions of the repo, but for 1.1.0 you need to change that line to:

images = self.decoder.sample(image_embed = image_embed, text = text_cond, cond_scale = cond_scale)

Here is my script to get things running:

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

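# The decoder .pth does not include CLIP, so copy the pretrained CLIP weights in under the "clip." prefix
clip_state = decoder.clip.state_dict()
for k, v in clip_state.items():
    decoder_model_state["clip." + k] = v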

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.
).cpu()

print(images.shape)

for img in images:
    img = ToPILImage()(img)
    img.show()

Let me know if you have any questions. Hope this helps.

Hi, I got an error from the decoder. I suppose it is due to the dalle2-pytorch version. Does this actually work with the current version of the repository?

decoder.load_state_dict(decoder_model_state, strict=True)
_IncompatibleKeys(
    missing_keys=['unets.0.ups.0.3.net.0.weight', 'unets.0.ups.0.3.net.0.bias',
                  'unets.0.ups.1.3.net.0.weight', 'unets.0.ups.1.3.net.0.bias',
                  'unets.0.ups.2.3.net.0.weight', 'unets.0.ups.2.3.net.0.bias',
                  'unets.0.ups.3.3.net.0.weight', 'unets.0.ups.3.3.net.0.bias'],
    unexpected_keys=['unets.0.ups.0.3.weight', 'unets.0.ups.0.3.bias',
                     'unets.0.ups.1.3.weight', 'unets.0.ups.1.3.bias',
                     'unets.0.ups.2.3.weight', 'unets.0.ups.2.3.bias',
                     'unets.0.ups.3.3.weight', 'unets.0.ups.3.3.bias'])
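The key names in that output suggest that the newer code wraps the layer at index 3 of each up block in an extra container (`...ups.N.3.weight` became `...ups.N.3.net.0.weight`). If renaming were the only change, remapping the checkpoint keys could let the old weights load. This is a hypothetical sketch inferred purely from the log, not a fix confirmed by the repository: if the wrapped layer's shape or behavior also changed, `load_state_dict(strict=True)` will report a size mismatch, or the images may still be wrong. Pinning dalle2-pytorch 1.1.0, as described earlier in the thread, is the safer route.

```python
import re
from typing import Dict, TypeVar

V = TypeVar("V")

# Hypothetical mapping inferred only from the missing/unexpected keys above:
# old checkpoints store up-block layer 3 directly, newer code expects `.net.0`.
_OLD_UP_LAYER = re.compile(r"^(unets\.\d+\.ups\.\d+\.3)\.(weight|bias)$")


def remap_upsample_keys(state: Dict[str, V]) -> Dict[str, V]:
    """Return a copy of `state` with old-style up-block keys renamed.

    Keys that do not match (including already-renamed ones) are kept as-is.
    """
    remapped: Dict[str, V] = {}
    for key, value in state.items():
        remapped[_OLD_UP_LAYER.sub(r"\1.net.0.\2", key)] = value
    return remapped


# Example (assumes the decoder and checkpoint from the script above):
#   decoder.load_state_dict(remap_upsample_keys(decoder_model_state), strict=True)
```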

hanghaju commented 1 year ago

This is my code for generating images, but the generated images are very blurry.
Prior model: https://huggingface.co/laion/DALLE2-PyTorch/tree/main/prior
Decoder model: https://huggingface.co/laion/DALLE2-PyTorch/tree/main/decoder/v1.0.2

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")

prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    text = ['A high quality photo of Times Square'],
    cond_scale = 2.
)

for img in images:
    img.save("out.jpg")

[image]

ctxya1207 commented 11 months ago

[image]
What is the reason for this error?

ALLIZZWELL123 commented 9 months ago

Where is the cuz folder, bro?

thePOET8 commented 7 months ago

I have the same problem. I changed the version to 1.1.0, but there are still some errors.