lucidrains / DALLE-pytorch

Implementation / replication of DALL-E, OpenAI's Text to Image Transformer, in Pytorch
MIT License

CogView Thinks Image and Text Should Be Weighted the Same #266

Open afiaka87 opened 3 years ago

afiaka87 commented 3 years ago

In the CogView paper they claim that by giving the text as much importance as the image they achieve a better result. They "hypothesize" that this is because the transformer is learning both how to predict images from text and how logic/knowledge/info works in general. As far as I can tell, it isn't mentioned again, unfortunately.

At any rate - perhaps we should run some tests with the --loss_img_weight parameter set to 1?

afiaka87 commented 3 years ago

I've done a run using --loss_img_weight 1 and setting the presently hidden stable parameter to True in the DALLE initialization.
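For reference, a minimal sketch of that configuration, assuming the current DALLE constructor (class names and defaults may differ by version); every other hyperparameter here is just a placeholder, not the value used in this run:

```python
from dalle_pytorch import DALLE, VQGanVAE

vae = VQGanVAE()  # pretrained ImageNet VQGAN used as the image tokenizer

# Placeholder hyperparameters; only loss_img_weight and stable reflect the run above.
dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 8192,   # matches the 8192-token BPE mentioned below
    text_seq_len = 256,
    depth = 16,
    heads = 8,
    loss_img_weight = 1,      # default is 7, i.e. image loss weighted 7x the text loss
    stable = True             # the currently undocumented stability flag
)
```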

Here is a W&B report. I'm not tracking text and image loss separately, although the average loss seems to converge much more quickly; I assume that has something to do with the re-weighting producing a "different loss curve". Happy to be corrected.

https://wandb.ai/dalle-pytorch-replicate/illustrations_imagenetvqgan/reports/Snapshot-Jun-6-2021-12-43pm--Vmlldzo3NTYxMjE?accessToken=hhov3b0wsf56tts63wx4qijkl4pnpiogizoh6a32bdctvngy5rvwtygjqpfyl1uj

@lucidrains @robvanvolt @rom1504 @gabriel_syme @janEbert @mehdidc

afiaka87 commented 3 years ago

Here is the byte pair encoding I used: a vocab size of 8192, covering 99.999% of all unique characters in about 6 million captions from Conceptual Captions. Perhaps overkill for these illustrations, actually, which have a more limited vocabulary.

https://www.dropbox.com/s/ay01p8zegfwof8t/variety.bpe
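If anyone wants to reproduce a similar BPE, here's a rough sketch using youtokentome (which, as far as I know, is what the repo's --bpe_path expects); the input file name and coverage value are assumptions:

```python
import youtokentome as yttm

# Train an 8192-token BPE on a text file with one caption per line.
yttm.BPE.train(
    data = 'conceptual_captions.txt',  # hypothetical path: ~6M captions, one per line
    model = 'variety.bpe',
    vocab_size = 8192,
    coverage = 0.99999,                # character coverage, as described above
)

# Quick sanity check: encode a caption to token ids.
bpe = yttm.BPE(model = 'variety.bpe')
print(bpe.encode(['a royalty free illustration of a cat'], output_type = yttm.OutputType.ID))
```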

Here is a checkpoint from the most recent iteration (still training). I decided to name the checkpoint "royalty free", as the dataset largely consists of 570,000 royalty-free illustrations from the Conceptual Captions dataset.

Inspired by CogView

https://www.dropbox.com/s/drpkcmr6b3zbftm/royalty_free.pt
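For anyone who wants to sample from it, something along these lines should work. The checkpoint key names and constructor call are assumptions based on how train_dalle.py usually saves things, so they may need adjusting; the repo's generate.py does the same thing end to end:

```python
import torch
from dalle_pytorch import DALLE, VQGanVAE
from dalle_pytorch.tokenizer import YttmTokenizer

# Load the checkpoint; train_dalle.py stores hyperparameters and weights in a dict
# (key names may differ between versions of the repo).
ckpt = torch.load('royalty_free.pt', map_location = 'cpu')

vae = VQGanVAE()                             # the ImageNet VQGAN used for this run
dalle = DALLE(vae = vae, **ckpt['hparams'])
dalle.load_state_dict(ckpt['weights'])

# Tokenize a caption with the custom BPE above and sample images.
tokenizer = YttmTokenizer('variety.bpe')
text = tokenizer.tokenize(['a royalty free illustration of a fox'], dalle.text_seq_len)
images = dalle.generate_images(text, filter_thres = 0.9)  # (batch, 3, H, W) tensor
```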

afiaka87 commented 3 years ago

Fun experiment - it didn't really pan out. There seems to be a mild tradeoff where the generations match maybe just one word in the caption rather than the full caption. The loss was still going down, but I'm not sure what else they were doing in the paper that may have made the text loss more impactful.

janEbert commented 3 years ago

Seems you are close to figuring it out, though! Do you think training further would fix the issue where too few words are attended to?

afiaka87 commented 3 years ago

@janEbert could the text sequence length be at play here? I've got some preliminary test runs suggesting this is the case.

Increasing the text_seq_len causes the loss to converge at a much higher value.

For instance, using a text_seq_len of 384:

If you decrease the image loss weight to 4, training converges to a lower value of ~5.5.

By decreasing the image loss weight to 1, I finally (and very quickly) converge to the loss I'm used to on a "diverse enough" dataset (for lack of more rigorous words).

[Screenshot: W&B loss curves, Jun 11 2021]
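For context, this is roughly how the repo combines the two cross-entropy losses (paraphrased from the dalle_pytorch source as I remember it; the snippet below is a toy, self-contained version with random tensors standing in for the model outputs):

```python
import torch
import torch.nn.functional as F

# Dummy shapes standing in for the real logits/targets.
text_seq_len, image_seq_len = 384, 1024
num_text_tokens, num_image_tokens = 8192, 1024
text_logits  = torch.randn(1, num_text_tokens, text_seq_len)
image_logits = torch.randn(1, num_image_tokens, image_seq_len)
text_labels  = torch.randint(0, num_text_tokens, (1, text_seq_len))
image_labels = torch.randint(0, num_image_tokens, (1, image_seq_len))

loss_text = F.cross_entropy(text_logits, text_labels)
loss_img  = F.cross_entropy(image_logits, image_labels)

# loss_img_weight = 7 is the repo default (the DALL-E paper's 7:1 weighting);
# loss_img_weight = 1 is the equal weighting CogView argues for.
loss_img_weight = 1
loss = (loss_text + loss_img_weight * loss_img) / (loss_img_weight + 1)
print(loss.item())
```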

CogView used 1024 text tokens. Re-reading their wording, they don't necessarily criticize the weighting so much as they criticize OpenAI for assuming the impact of text on the loss was merely "auxiliary". So perhaps the difference in weighting here is due to their training on Chinese text? I'm regrettably uninformed about the differences between English and Chinese and how that relates to training language models.

Is it possible that OpenAI's dataset simply didn't benefit much from a lengthy text sequence because their average caption wasn't that long in the first place? In that case it would perhaps be wiser to scale other parts of the model?

I'll follow up with reconstructions in a bit.

janEbert commented 3 years ago

Interesting thought about the differences in training on Chinese versus English. Although, assuming the same word embeddings for both languages, translated text should come out to roughly the same number of tokens for both languages, right?

I haven't read CogView yet so I have absolutely no intuition at the moment, either. Sorry!

afiaka87 commented 3 years ago

@janEbert

I think they may have been trying to underfit the data for some reason, because of issues which are perhaps more apparent at the scale OpenAI was operating at. There's a good deal of hand-waving/we-did-it-because-they-did-it in the DALL-E paper, though, so I need to revisit their reasoning for weighting the text-predicts-image loss more.

The intuition provided by CogView is equally hand-wavy in my opinion, and I wouldn't be surprised if the original weighting would hit the same loss curve over roughly the same amount of time due to scaling laws.

If I may commit a bit of "academic fraud": my new intuition, based just on running a whole bunch of generations with image and text loss weighted the same, is that there is a relationship between the weighting and the noise of your dataset. If you have very lengthy captions and they are all accurate and concise, use a lower text-predicts-image weight and a higher text_seq_len. The text portion of the transformer will indeed benefit from "being able to learn more about the language modality", although probably mostly due to the increase in text_seq_len.

A good example is Open Images Localized Annotations (~500k image-text pairs). I trained on that dataset with a text_seq_len in excess of 600. With the image loss weight at seven, the loss obviously doesn't go down nearly as quickly, but you can get some pretty great-looking outliers a lot earlier on. When I weighted them the same (please remember, this is all intuition/hypothesis), the loss gets to within an order of magnitude of where it more or less settles much faster. The images look better on average perceptually (although those same outlier examples do tend to look worse - worse according to errors you can see relative to the caption, i.e. they look worse perceptually, not conceptually). Does that make sense?

I haven't experimented with a "noisy" dataset, but I bet a good example would be something like Wikipedia articles from WIT. Open Images Localized Annotations is literally the transcribed spoken word of a human being describing an image as they look at it. Sure, it has errors and such, but most of the words in the dataset are directly about what is in the image itself - not a discussion of the history of the subject in the image or what have you. As such, you may want to underfit the transformer's text-predicts-text loss by increasing the text-predicts-image weight on such a dataset. Perhaps this allows the model to underfit the language modality and be "more open to interpretation", so to speak.

Again - who knows; this stuff would all be fascinating to visualize if we had a heatmap of the attention heads. Unfortunately I have no idea how to implement that.
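In case it helps anyone get started, here's a purely illustrative sketch: assuming the per-head attention weights have been extracted somehow (e.g. by modifying the Attention module's forward to also return its post-softmax matrix), plotting one head as a heatmap is straightforward with matplotlib. The tensor below is random data standing in for real weights:

```python
import torch
import matplotlib.pyplot as plt

# Dummy attention weights standing in for one layer: (heads, queries, keys).
heads, seq_len = 8, 64
attn = torch.randn(heads, seq_len, seq_len).softmax(dim = -1)

fig, ax = plt.subplots()
im = ax.imshow(attn[0].numpy(), cmap = 'viridis', aspect = 'auto')
ax.set_xlabel('key position')
ax.set_ylabel('query position')
ax.set_title('attention head 0')
fig.colorbar(im)
fig.savefig('attention_head_0.png')
```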

janEbert commented 3 years ago

Super interesting insights! I also agree about the scale; it's probably hard to compare these super-large-scale models with the "casual" ones.

Visualization also sounds like a cool feature for understanding. Maybe we can extend BertViz at some point?