microsoft / VQ-Diffusion

Official implementation of VQ-Diffusion
MIT License

cannot reproduce inference using coco_pretrained.pth #12

Open Lizb6626 opened 2 years ago

Lizb6626 commented 2 years ago

When I run the code below, I get an AttributeError.

VQ_Diffusion_model = VQ_Diffusion(config='OUTPUT/pretrained_model/config_text.yaml', path='OUTPUT/pretrained_model/coco_learnable.pth')
VQ_Diffusion_model.inference_generate_sample_with_condition("A group of elephants walking in muddy water", truncation_rate=0.86, save_root="RESULT", batch_size=4)

The full output and traceback are as follows.

Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
/home/lizhibing/anaconda3/envs/vqdif/lib/python3.9/site-packages/torchvision/transforms/transforms.py:280: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  warnings.warn(
{'overall': {'trainable': '370.77M', 'non_trainable': '126.29M', 'total': '497.06M'}, 'content_codec': {'trainable': '0', 'non_trainable': '65.8M', 'total': '65.8M'}, 'condition_codec': {'trainable': '0', 'non_trainable': '0', 'total': '0'}, 'transformer': {'trainable': '370.77M', 'non_trainable': '60.49M', 'total': '431.26M'}}
Model missing keys:
 []
Model unexpected keys:
 ['transformer.empty_text_embed']
Evaluate EMA model
Traceback (most recent call last):
  File "/home/lizhibing/repo/VQ-Diffusion/inference_VQ_Diffusion.py", line 184, in <module>
    VQ_Diffusion_model.inference_generate_sample_with_condition("A group of elephants walking in muddy water", truncation_rate=1.0, save_root="RESULT", batch_size=4, guidance_scale=3.0)
  File "/home/lizhibing/repo/VQ-Diffusion/inference_VQ_Diffusion.py", line 126, in inference_generate_sample_with_condition
    model_out = self.model.generate_content(
  File "/home/lizhibing/anaconda3/envs/vqdif/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
  File "/home/lizhibing/repo/VQ-Diffusion/image_synthesis/modeling/models/dalle.py", line 164, in generate_content
    cf_cond_emb = self.transformer.empty_text_embed.unsqueeze(0).repeat(batch_size, 1, 1)
  File "/home/lizhibing/anaconda3/envs/vqdif/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DiffusionTransformer' object has no attribute 'empty_text_embed'

Lizb6626 commented 2 years ago

I fixed this problem by changing the default value of learnable_cf from False to True in the DiffusionTransformer constructor (diffusion_transformer.py, line 82). The config file (or wherever this flag is set) should be updated to be compatible with the classifier-free version.
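
For anyone hitting the same thing, here is a minimal sketch of what seems to be going on, in plain Python so it runs without the repo. ToyTransformer and load_non_strict are hypothetical stand-ins, not the actual DiffusionTransformer or the repo's loading code; the only things taken from the log above are that the checkpoint contains an empty_text_embed entry, that the loader reports it as an unexpected key instead of failing, and that the attribute is missing later.

```python
class ToyTransformer:
    """Hypothetical stand-in for DiffusionTransformer (not the repo's class)."""

    def __init__(self, learnable_cf=False):
        self.weight = [0.0]
        if learnable_cf:
            # Assumption: the embedding only exists when the flag is on.
            self.empty_text_embed = [0.0]


def load_non_strict(model, checkpoint):
    """Copy matching entries and return (missing, unexpected) key lists."""
    expected = set(vars(model))
    missing = sorted(expected - set(checkpoint))
    unexpected = sorted(set(checkpoint) - expected)
    for key in expected & set(checkpoint):
        setattr(model, key, checkpoint[key])
    return missing, unexpected


# Toy checkpoint that was saved from a model built with learnable_cf=True.
checkpoint = {"weight": [1.0], "empty_text_embed": [0.5]}

# Default constructor: the checkpoint entry is dropped as "unexpected",
# so any later access to model.empty_text_embed raises AttributeError.
default_model = ToyTransformer()
missing, unexpected = load_non_strict(default_model, checkpoint)
has_embed_default = hasattr(default_model, "empty_text_embed")

# With the flag enabled, the key has somewhere to go and loads normally.
cf_model = ToyTransformer(learnable_cf=True)
missing_cf, unexpected_cf = load_non_strict(cf_model, checkpoint)
has_embed_cf = hasattr(cf_model, "empty_text_embed")
```

In this toy version the default model ends up with the key listed as unexpected and no attribute, which matches the "Model unexpected keys" line and the AttributeError in the traceback. Building the model with learnable_cf=True lets the checkpoint value load.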

fido20160817 commented 1 year ago

You should set "learnable_cf = False" in the function inference_generate_sample_with_condition(...).