lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

p2_loss_weight causing problems for prior checkpoints #158

Closed: nousr closed this issue 2 years ago

nousr commented 2 years ago

The current prior checkpoints were trained before p2_loss_weight was added. As a result, it is not possible to load these checkpoints in the latest version.

According to @Veldrovive, setting strict=False does allow you to load the checkpoint, but it results in garbage predictions.
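A minimal sketch of the loading behavior described above. It uses toy modules rather than the real dalle2-pytorch classes (`OldPrior`, `NewPrior`, and the layer sizes are hypothetical). PyTorch's `load_state_dict` raises a `RuntimeError` under the default `strict=True` when the checkpoint lacks a key the model expects. With `strict=False` it loads whatever matches and returns the `missing_keys` and `unexpected_keys` it skipped. Missing entries keep whatever value the current code initialized them with.

```python
import torch
from torch import nn


class OldPrior(nn.Module):
    """Toy stand-in for a prior saved before p2_loss_weight existed (hypothetical)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 4)


class NewPrior(nn.Module):
    """Toy stand-in for the current prior, which registers an extra buffer (hypothetical)."""

    def __init__(self, timesteps: int = 10):
        super().__init__()
        self.net = nn.Linear(4, 4)
        self.register_buffer("p2_loss_weight", torch.ones(timesteps))


old_state = OldPrior().state_dict()  # plays the role of the old checkpoint
model = NewPrior()

# Strict loading refuses a state dict that lacks a key the model expects.
try:
    model.load_state_dict(old_state)
    strict_failed = False
except RuntimeError as err:
    strict_failed = True
    print("strict load failed:", str(err).splitlines()[0])

# Non-strict loading succeeds and reports what did not match.
result = model.load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)        # keys left at their init values
print("unexpected:", result.unexpected_keys)  # checkpoint keys the model ignored
```

Note that `strict=False` only relaxes key matching. A tensor whose shape differs from the model's still raises an error. A clean non-strict load therefore says nothing about whether the rest of the model behaves as it did under the old version.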

Is there a way we can patch this to allow the old checkpoints to work with the latest repo?

lucidrains commented 2 years ago

@nousr so i think the version the model was trained on should be saved in the .pt file, and you can just make sure to pip install dalle2-pytorch=={version} before loading it
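The version-pinning idea can be sketched in plain Python. The checkpoint layout (a dict with `"version"` and `"model"` keys) and the helper names `stamp_checkpoint`, `installed_version`, and `pin_command` are assumptions for illustration. They are not necessarily what dalle2-pytorch actually writes to its .pt files. The installed version is read with the standard-library `importlib.metadata.version`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

PACKAGE = "dalle2-pytorch"


def stamp_checkpoint(state: dict, trained_with: str) -> dict:
    """Bundle weights with the library version they were trained with (hypothetical layout)."""
    return {"version": trained_with, "model": state}


def installed_version(package: str = PACKAGE) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def pin_command(checkpoint: dict, current: Optional[str]) -> Optional[str]:
    """Return the pip command that matches the checkpoint's version, or None if no change is needed."""
    wanted = checkpoint.get("version")
    if wanted is None:  # no version recorded, so there is nothing to pin against
        return None
    if wanted == current:
        return None
    return f"pip install {PACKAGE}=={wanted}"


ckpt = stamp_checkpoint({"net.weight": [0.0]}, trained_with="0.7.0")
print(pin_command(ckpt, installed_version()))
```

In practice the stamped dict would be passed to `torch.save`, and `pin_command` would be checked after `torch.load` but before building the model.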

i actually made sure to build this in early on, because i had trouble with this in other repositories

Veldrovive commented 2 years ago

The problem is that the pretrained decoder only works on versions >0.10.1, while the pretrained prior only works on 0.7.0. We are just looking for a workaround so that the prior and decoder can be loaded at the same time without retraining the entire prior on the newest version: some value to put in for p2_loss_weight in the state dict, or something similar.
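One way to supply "some value" for the missing key is to copy it from a freshly constructed new-version model, then load strictly so that any other mismatch still fails. This sketch assumes `p2_loss_weight` is a registered buffer whose value the current code computes at construction time. The toy `NewPrior`, its placeholder `torch.linspace` values, and the `fill_missing_keys` helper are hypothetical. In a real checkpoint the key may also carry a submodule prefix.

```python
from typing import Dict, List, Tuple

import torch
from torch import nn


class NewPrior(nn.Module):
    """Toy stand-in for the current prior (hypothetical)."""

    def __init__(self, timesteps: int = 10):
        super().__init__()
        self.net = nn.Linear(4, 4)
        # Placeholder values standing in for a buffer the current code derives at construction.
        self.register_buffer("p2_loss_weight", torch.linspace(1.0, 0.1, timesteps))


def fill_missing_keys(
    model: nn.Module, old_state: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], List[str]]:
    """Add every key the model expects but the old state dict lacks, using the model's own values."""
    patched = dict(old_state)
    added = []
    for key, value in model.state_dict().items():
        if key not in patched:
            patched[key] = value.clone()
            added.append(key)
    return patched, added


# Stands in for a checkpoint saved before p2_loss_weight existed.
old_state = {"net.weight": torch.randn(4, 4), "net.bias": torch.zeros(4)}
model = NewPrior()
patched, added = fill_missing_keys(model, old_state)
model.load_state_dict(patched)  # default strict=True: any other mismatch would still raise
print("patched keys:", added)
```

For the patched keys this is equivalent to what `strict=False` already does, since they end up at their initialization values. The difference is that strict loading then still catches any other missing or unexpected key. It cannot restore behavior that changed between versions in ways the state dict does not capture.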

lucidrains commented 2 years ago

i think it is probably best to just load with strict=False and finetune for an epoch or two
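Here is a sketch of the suggested workaround. It loads the checkpoint non-strictly but only tolerates the keys you expect to be missing, then fine-tunes briefly. The toy model, the random data, the MSE loss, and the `load_for_finetune` helper are placeholders (assumptions). A real run would use the prior's own training objective and data.

```python
import torch
from torch import nn


class NewPrior(nn.Module):
    """Toy stand-in for the current prior (hypothetical)."""

    def __init__(self, timesteps: int = 10):
        super().__init__()
        self.net = nn.Linear(4, 4)
        self.register_buffer("p2_loss_weight", torch.ones(timesteps))

    def forward(self, x):
        return self.net(x)


def load_for_finetune(model, state, allowed_missing):
    """Load non-strictly, but fail loudly if anything beyond the expected keys is off."""
    result = model.load_state_dict(state, strict=False)
    surprises = set(result.missing_keys) - set(allowed_missing)
    if surprises or result.unexpected_keys:
        raise RuntimeError(f"missing={sorted(surprises)} unexpected={result.unexpected_keys}")
    return model


torch.manual_seed(0)
# Stands in for the old checkpoint.
old_state = {"net.weight": torch.randn(4, 4), "net.bias": torch.zeros(4)}
model = load_for_finetune(NewPrior(), old_state, allowed_missing={"p2_loss_weight"})

# Toy data and loss; the real prior would use its own training objective.
inputs, targets = torch.randn(64, 4), torch.randn(64, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
losses = []
for epoch in range(2):  # "an epoch or two"
    for x, y in zip(inputs.split(16), targets.split(16)):
        loss = nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses.append(loss.item())  # last batch loss of each epoch
print(losses)
```

The `allowed_missing` guard is the main design choice here. A bare `strict=False` would also silently accept a checkpoint that is missing real weights.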

Veldrovive commented 2 years ago

Ah, that's smart. @nousr would you be able to do that?