Open pragnakalpdev7 opened 3 years ago
@pragnakalpdev7 oops! has been fixed! https://github.com/lucidrains/DALLE-pytorch/commit/e1a847ce8286c0aa62bf501ddb7fb16cb3e86a69
Thank you for the quick response, @lucidrains.
We have already tried the changes you suggested, but we are now facing the error below:
Traceback (most recent call last):
  File "generate.py", line 65, in <module>
    dalle = DALLE(vae = vae, **dalle_params).cuda()
TypeError: type object got multiple values for keyword argument 'vae'
@pragnakalpdev7 removing this line should do the trick. The vae will get passed in properly later on.
Edit: Oh, I'm sorry, I confused your issue with one I was having in train_dalle.py. I'm not seeing the same bug in generate.py, but I can't run it myself currently.
Again, I can't verify this myself, but if you're using a pretrained VAE (OpenAIDiscreteVAE, for instance), then I believe the value for 'vae' will actually be None. The dalle_params dictionary may already contain the vae entry, which would explain the duplicate keyword. If so, you can just remove that argument from your call to DALLE:
dalle = DALLE(**dalle_params).cuda()
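To illustrate why the duplicate keyword fails, here is a minimal sketch. FakeDALLE and the dictionary contents are hypothetical stand-ins, not the real DALLE-pytorch class or checkpoint contents; the only assumption carried over from above is that dalle_params may already hold a 'vae' key.

```python
# Hypothetical stand-in for the real DALLE class; only keyword handling matters here.
class FakeDALLE:
    def __init__(self, vae=None, dim=512):
        self.vae = vae
        self.dim = dim


# Assumed for illustration: saved hyperparameters that already include a 'vae' entry.
dalle_params = {"vae": "vae-from-params", "dim": 256}
vae = "vae-built-separately"

# Passing vae= explicitly while the dict also supplies 'vae' raises TypeError.
try:
    FakeDALLE(vae=vae, **dalle_params)
    duplicate_error = None
except TypeError as err:
    duplicate_error = str(err)

# Option 1: rely on the dictionary alone, as suggested above.
model_a = FakeDALLE(**dalle_params)

# Option 2: drop the duplicate key, then pass vae explicitly.
params_without_vae = {k: v for k, v in dalle_params.items() if k != "vae"}
model_b = FakeDALLE(vae=vae, **params_without_vae)
```

Option 2 may be the safer choice if the 'vae' entry in the dictionary turns out to be None, since the separately built VAE is then the one that actually gets used.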
Does that work?
Hello @afiaka87,
We have already tried this, but we are not getting good results. We think we may be missing something.
@pragnakalpdev7 one more patch! https://github.com/lucidrains/DALLE-pytorch/commit/f2b02bab902f4acd435b0eadb4442e2788c0fcae
Hi, @lucidrains,
We have trained the model using OpenAI's pretrained VAE. We used the train_dalle.py file for training:
python train_dalle.py --image_text_folder=./dataset_images
After training, we tried to generate results from the saved model with the following command:
python generate.py --dalle_path=/content/DALLE-pytorch/dalle-final.pt --text='bird has wings that are brown and has a red crown' --num_images=20
Running the above command produces this error:
Traceback (most recent call last):
  File "generate.py", line 55, in <module>
    vae = DiscreteVAE(**vae_params)
TypeError: type object argument after ** must be a mapping, not NoneType
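For reference, Python raises this TypeError whenever ** is applied to None, which suggests vae_params was None when generate.py reached that line. A minimal sketch of one possible guard follows; FakeDiscreteVAE, FakePretrainedVAE, and build_vae are hypothetical stand-ins, and this is not necessarily how the repository handles the case.

```python
# Hypothetical stand-ins; the real classes live in DALLE-pytorch and take other arguments.
class FakeDiscreteVAE:
    def __init__(self, image_size=256, num_tokens=8192):
        self.image_size = image_size
        self.num_tokens = num_tokens


class FakePretrainedVAE:
    """Stands in for a pretrained VAE that needs no saved hyperparameters."""


def build_vae(vae_params):
    # Unpack only when hyperparameters were actually saved.
    if vae_params is not None:
        return FakeDiscreteVAE(**vae_params)
    # Assumption: None means a pretrained VAE was used during training.
    return FakePretrainedVAE()


# Reproduce the reported failure: ** cannot unpack None.
try:
    FakeDiscreteVAE(**None)
    unpack_error = None
except TypeError as err:
    unpack_error = str(err)

vae_from_none = build_vae(None)
vae_from_dict = build_vae({"image_size": 128})
```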