lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Add ability to train decoder using embedding-image pairs #63

Closed by Veldrovive 2 years ago

Veldrovive commented 2 years ago

I am implementing a single-node training script for the decoder, and it seems @lucidrains has already implemented a full-featured wrapper for this purpose. Currently, the forward pass is implemented as follows: https://github.com/lucidrains/DALLE2-pytorch/blob/1d5dc088109c5d606096d40bea59ff8c7b8e7d9f/dalle2_pytorch/train.py#L189-L199

This lacks the ability to substitute our own image embeddings when we have precomputed embedding-image pairs. The Decoder network already mostly supports this, since image_embed can be passed to its forward method, so the feature could be implemented by simply passing the image_embed parameter through to decoder.forward. However, it would also be convenient to make the clip model optional in the Decoder constructor. I started on this a week ago in this branch by adding the ability to set clip_image_size and channels separately from a clip model.

Only a few small changes are needed to implement this feature, so I could put together a pull request for it.
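To make the request concrete, here is a minimal sketch of the behaviour being asked for: an optional encoder in the constructor plus an image_embed argument that bypasses it. ToyDecoder, its clip callable, and the placeholder loss are all hypothetical stand-ins written in plain Python; this is not the DALLE2-pytorch Decoder and does not mirror its real signature.

```python
from typing import Callable, List, Optional

Vector = List[float]


def _mean(values: Vector) -> float:
    return sum(values) / len(values)


class ToyDecoder:
    """Hypothetical stand-in for a decoder with an optional encoder."""

    def __init__(
        self,
        clip: Optional[Callable[[Vector], Vector]] = None,
        image_size: Optional[int] = None,
        channels: int = 3,
    ) -> None:
        # Without an encoder, the image geometry has to be given explicitly.
        if clip is None and image_size is None:
            raise ValueError("image_size is required when no clip encoder is given")
        self.clip = clip
        self.image_size = image_size
        self.channels = channels

    def forward(self, image: Vector, image_embed: Optional[Vector] = None) -> float:
        # Prefer a precomputed embedding; fall back to the encoder if present.
        if image_embed is None:
            if self.clip is None:
                raise ValueError("pass image_embed or construct with a clip encoder")
            image_embed = self.clip(image)
        # Placeholder "loss" so the example returns something checkable.
        return abs(_mean(image) - _mean(image_embed))


# Training from a precomputed embedding-image pair, with no encoder at all.
decoder = ToyDecoder(image_size=64)
loss = decoder.forward([1.0, 3.0], image_embed=[2.0, 2.0])
```

With this shape, a dataset of embedding-image pairs never needs the encoder loaded, while callers that only have images can still supply one.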

lucidrains commented 2 years ago

@Veldrovive Hi Aidan! Indeed that is the case, and I can get this finished in the next half hour. I've been meaning to get around to it!

lucidrains commented 2 years ago

@Veldrovive here you go https://github.com/lucidrains/DALLE2-pytorch/releases/tag/0.0.106

Veldrovive commented 2 years ago

Great! For the DecoderTrainer, are you planning to just pass image_embed through kwargs rather than adding a specific named parameter for it?

lucidrains commented 2 years ago

@Veldrovive yup, for wrappers i usually just forward kwargs to whatever is being wrapped (instead of using some fancy forwarding module)
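The kwargs-forwarding pattern described in this reply can be sketched roughly as follows. EchoDecoder and KwargsForwardingTrainer are hypothetical names invented for illustration, and the step counter stands in for whatever bookkeeping a real trainer does; this is not the actual DecoderTrainer code.

```python
class EchoDecoder:
    """Hypothetical wrapped object: returns the arguments it received."""

    def forward(self, image, image_embed=None, text=None):
        return {"image": image, "image_embed": image_embed, "text": text}


class KwargsForwardingTrainer:
    """Hypothetical wrapper that forwards all arguments unchanged."""

    def __init__(self, decoder):
        self.decoder = decoder
        self.steps = 0

    def forward(self, *args, **kwargs):
        # The wrapper does its own bookkeeping, then hands every positional
        # and keyword argument to the wrapped object without naming them.
        self.steps += 1
        return self.decoder.forward(*args, **kwargs)


trainer = KwargsForwardingTrainer(EchoDecoder())
out = trainer.forward("img", image_embed="emb")
```

The trade-off is that the wrapper never has to change when the wrapped forward method gains a parameter such as image_embed, at the cost of a less self-documenting signature on the wrapper itself.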