lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

Implementing some features from cogview #263

Open rom1504 opened 3 years ago

rom1504 commented 3 years ago

CogView has these fine-tuning abilities:

I think they are all pretty cool, and they seem simple enough in the paper. I wonder if we could implement them here.

afiaka87 commented 3 years ago

@lucidrains I heard through the grapevine that you're working on some of this stuff? Is that correct?

afiaka87 commented 3 years ago

this is related https://github.com/lucidrains/DALLE-pytorch/issues/266

rom1504 commented 3 years ago

From the CogView paper (https://arxiv.org/abs/2105.13290):

Super resolution

We first finetune a super-resolution model from 16 × 16 image tokens to 32 × 32 tokens, and then use it to magnify generated 32 × 32 tokens to 64 × 64 tokens by a center-continuous sliding-window strategy in Figure 5, finally resulting in an image of 512 × 512 pixels. To prepare data, we crop about 2 million 256 × 256 patches and downsample them to 128 × 128. Then we get 16 × 16 and 32 × 32 sequence pairs after tokenization for different resolution. The pattern of finetuning sequence is “[ROI1] text tokens [BASE][BOI1] 256 image tokens [EOI1][ROI2][BASE] [BOI2] 1024 image tokens [EOI2]”, exceeding the max position embedding index 1088. As a solution, we recount the position index from 0 at [ROI2]. In practice, the model can distinguish the two images well, probably based on whether they can attend to a [ROI2] in front.
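A minimal sketch of how the finetuning sequence and the restarted position ids could be built. The special-token ids, the `build_super_res_sequence` helper, and the assumption that the transformer accepts explicit position ids are all hypothetical, not something that exists in DALLE-pytorch today:

```python
import torch

# hypothetical special-token ids; in practice these would come from the tokenizer / model config
ROI1, BASE, BOI1, EOI1, ROI2, BOI2, EOI2 = 0, 1, 2, 3, 4, 5, 6

def build_super_res_sequence(text_tokens, low_res_tokens, high_res_tokens):
    """Build the CogView-style pattern
    [ROI1] text [BASE][BOI1] 256 low-res tokens [EOI1] [ROI2][BASE][BOI2] 1024 high-res tokens [EOI2]
    plus position ids that restart from 0 at [ROI2], as described in the paper."""
    first = torch.cat([
        torch.tensor([ROI1]), text_tokens,
        torch.tensor([BASE, BOI1]), low_res_tokens,
        torch.tensor([EOI1]),
    ])
    second = torch.cat([
        torch.tensor([ROI2, BASE, BOI2]), high_res_tokens,
        torch.tensor([EOI2]),
    ])
    tokens = torch.cat([first, second])
    # recount the position index from 0 at [ROI2] so the sequence stays within
    # the max position embedding index
    position_ids = torch.cat([torch.arange(len(first)), torch.arange(len(second))])
    return tokens, position_ids

# toy usage with random token ids (16x16 -> 256 tokens, 32x32 -> 1024 tokens)
text = torch.randint(7, 100, (16,))
low  = torch.randint(7, 100, (256,))
high = torch.randint(7, 100, (1024,))
tokens, position_ids = build_super_res_sequence(text, low, high)
```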

Captioning

To finetune CogView for image captioning is straightforward: exchanging the order of text and image tokens in the input sequences. Since the model has already learnt the corresponding relationships between text and images, reversing the generation is not hard.
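If someone wanted to try this here, a rough sketch could look like the following, assuming a decoder-only transformer `model` that maps a token sequence to next-token logits; the function name and the loss masking are illustrative, not the paper's code:

```python
import torch
import torch.nn.functional as F

def captioning_finetune_loss(model, image_tokens, text_tokens):
    """Swap the usual text-then-image order: feed image tokens first and
    compute the loss only on the text (caption) positions."""
    seq = torch.cat([image_tokens, text_tokens], dim=-1)   # (batch, img_len + txt_len)
    logits = model(seq[:, :-1])                            # next-token prediction logits
    targets = seq[:, 1:].clone()
    targets[:, : image_tokens.shape[-1] - 1] = -100        # ignore image positions in the loss
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        targets.reshape(-1),
        ignore_index=-100,
    )
```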

Image text score

We propose the Caption Score (CapS) to evaluate the correspondence between images and text. More specifically, CapS(x, t) = \sqrt[|t|]{\prod_{i=0}^{|t|} p(t_i \mid x, t_{0:i-1})}, where t is a sequence of text tokens and x is the image. log CapS(x, t) is the cross-entropy loss for the text tokens, and this method can be seen as an adaptation of inverse prompting [49] for text-to-image generation. Finally, images with the highest CapS are chosen.
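Concretely, CapS is just the geometric mean of the per-token caption probabilities, so it can be computed from the same forward pass used for captioning. A sketch under the same assumption of a hypothetical decoder-only `model` returning next-token logits:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def caption_score(model, image_tokens, text_tokens):
    """CapS(x, t): the |t|-th root of the product of p(t_i | x, t_{0:i-1}),
    i.e. exp of the mean log-likelihood of the caption tokens given the image."""
    seq = torch.cat([image_tokens, text_tokens], dim=-1)
    logits = model(seq[:, :-1])                       # (batch, len-1, vocab)
    log_probs = F.log_softmax(logits, dim=-1)
    img_len = image_tokens.shape[-1]
    text_targets = seq[:, img_len:]                   # caption tokens to score
    text_log_probs = log_probs[:, img_len - 1:]       # positions that predict those tokens
    token_logp = text_log_probs.gather(-1, text_targets.unsqueeze(-1)).squeeze(-1)
    return token_logp.mean(dim=-1).exp()              # one score per batch element

# candidate images would then be ranked by caption_score and the highest kept
```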

I like the simplicity of all 3 methods.

I might give it a try; it seems too simple not to try :)

afiaka87 commented 3 years ago

@rom1504 perhaps now it might be easier?

rom1504 commented 3 years ago

yeah most likely indeed!

rom1504 commented 3 years ago

https://github.com/THUDM/CogView/tree/main/finetune, although I am not sure whether these fine-tuning methods are implemented in the CogView repo. I don't think they would be particularly difficult to implement, though.