lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

Getting "type object argument after ** must be a mapping" error in image generation #77

Open pragnakalpdev7 opened 3 years ago

pragnakalpdev7 commented 3 years ago

Hi, @lucidrains,

We have trained the model using OpenAI's pretrained VAE. We used the train_dalle.py file for training:

    python train_dalle.py --image_text_folder=./dataset_images

After training, we tried to generate results from the saved model with the following command:

    python generate.py --dalle_path=/content/DALLE-pytorch/dalle-final.pt --text='bird has wings that are brown and has a red crown' --num_images=20

On running the above command, we are facing the error:

    Traceback (most recent call last):
      File "generate.py", line 55, in <module>
        vae = DiscreteVAE(**vae_params)
    TypeError: type object argument after ** must be a mapping, not NoneType
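For readers hitting the same message: the traceback says that vae_params was None at the point where it was unpacked with **. The sketch below reproduces that failure in isolation and shows one possible guard. DiscreteVAEStub, build_vae and unpack_error_message are hypothetical names used only for illustration; this is not the code from generate.py, and the assumption that vae_params is None when a pretrained VAE was used for training is inferred from this thread, not verified against the repository.

```python
class DiscreteVAEStub:
    """Hypothetical stand-in for DiscreteVAE; it only records its keyword arguments."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def unpack_error_message(vae_params):
    """Return the TypeError text raised by unpacking vae_params, or None if it works."""
    try:
        DiscreteVAEStub(**vae_params)
    except TypeError as exc:
        return str(exc)
    return None


def build_vae(vae_params):
    """Build the stub only when hyperparameters were actually saved.

    Assumption: a value of None means no custom VAE hyperparameters were
    stored, so the caller has to supply a VAE some other way.
    """
    if vae_params is None:
        return None
    return DiscreteVAEStub(**vae_params)


if __name__ == "__main__":
    print(unpack_error_message(None))
    print(build_vae({"image_size": 256}).kwargs)
```

The exact wording of the message varies between Python versions, but the "must be a mapping, not NoneType" part is the same.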

lucidrains commented 3 years ago

@pragnakalpdev7 oops! has been fixed! https://github.com/lucidrains/DALLE-pytorch/commit/e1a847ce8286c0aa62bf501ddb7fb16cb3e86a69

pragnakalpdev7 commented 3 years ago

Thank you for the quick response, @lucidrains.

We have already tried the changes you suggested, but we are now facing the error below:

    Traceback (most recent call last):
      File "generate.py", line 65, in <module>
        dalle = DALLE(vae = vae, **dalle_params).cuda()
    TypeError: type object got multiple values for keyword argument 'vae'
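This second message appears whenever the same keyword is supplied twice in one call, here once explicitly as vae = vae and once through a 'vae' key inside the unpacked dictionary. A minimal reproduction, using a hypothetical DalleStub class and made-up parameter values in place of the real DALLE and the real checkpoint contents:

```python
class DalleStub:
    """Hypothetical stand-in for DALLE; it only records its keyword arguments."""

    def __init__(self, *, vae=None, dim=512):
        self.vae = vae
        self.dim = dim


def duplicate_keyword_message(explicit_vae, params):
    """Return the TypeError text raised when 'vae' is passed twice, else None."""
    try:
        DalleStub(vae=explicit_vae, **params)
    except TypeError as exc:
        return str(exc)
    return None


if __name__ == "__main__":
    # Assumption for illustration: the saved parameters include a 'vae' key.
    saved_params = {"vae": None, "dim": 256}
    print(duplicate_keyword_message("explicit vae", saved_params))
```

The error does not depend on the value stored under 'vae'. The key being present in the dictionary is enough to trigger it.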

afiaka87 commented 3 years ago

@pragnakalpdev7 removing this line should do the trick. The vae will get passed in properly later on.

https://github.com/lucidrains/DALLE-pytorch/blob/720a6060084e27a99f343802d3799a9fa8797ec6/train_dalle.py#L78

Edit: Oh, I'm sorry, I confused your issue with one I was having in train_dalle.py. I'm not seeing the same bug in generate.py, but I can't run it myself currently.

afiaka87 commented 3 years ago

Again, I can't verify this myself, but if you're using a pretrained VAE (OpenAIDiscreteVAE, for instance), then I believe the value for 'vae' will actually be None. The dalle_params dictionary may already contain the 'vae' entry, which is why passing vae explicitly as well causes the conflict. As such, you can just remove that argument from your call to DALLE:

dalle = DALLE(**dalle_params).cuda()

Does that work?
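For reference, there are two general ways to avoid passing 'vae' twice: rely on the dictionary alone, as suggested above, or strip the key from a copy of the dictionary and pass the VAE explicitly. The sketch below shows both with a hypothetical DalleStub class and hypothetical helper names. Which one is appropriate depends on what the checkpoint actually stores under 'vae', and that has not been verified against the repository here.

```python
class DalleStub:
    """Hypothetical stand-in for DALLE; it only records its arguments."""

    def __init__(self, *, vae=None, **kwargs):
        self.vae = vae
        self.kwargs = kwargs


def build_from_dict_only(dalle_params):
    """Option 1: pass only the saved dictionary, using whatever 'vae' it holds."""
    return DalleStub(**dalle_params)


def build_with_explicit_vae(vae, dalle_params):
    """Option 2: drop any saved 'vae' entry and pass the VAE explicitly."""
    params = dict(dalle_params)  # copy so the caller's dictionary is not mutated
    params.pop("vae", None)      # no error if the key is absent
    return DalleStub(vae=vae, **params)


if __name__ == "__main__":
    # Assumption for illustration: the saved parameters include a 'vae' key.
    saved_params = {"vae": None, "dim": 256}
    print(build_from_dict_only(saved_params).vae)
    print(build_with_explicit_vae("pretrained vae", saved_params).vae)
```

Note that with option 1 the model receives whatever the dictionary holds, so if that value is None the model is built without a usable VAE.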

pragnakalpdev7 commented 3 years ago

Hello @afiaka87,

We have already tried this trick, but we are not getting good results. We think we may be missing something.

lucidrains commented 3 years ago

@pragnakalpdev7 one more patch! https://github.com/lucidrains/DALLE-pytorch/commit/f2b02bab902f4acd435b0eadb4442e2788c0fcae