kohjingyu / gill

🐟 Code and models for the NeurIPS 2023 paper "Generating Images with Multimodal Language Models".
https://jykoh.com/gill
Apache License 2.0

A few questions about the training pipeline #5

Closed: avipartho closed this issue 1 year ago

avipartho commented 1 year ago

Hi,

I read your paper and it is great work! Thanks for sharing your codebase with the community. As I was going through your code, I came across a few places where I would greatly appreciate your explanations/suggestions. Here are my questions:

kohjingyu commented 1 year ago

Thanks for your kind words!

What does the CE loss from here stand for, i.e. which of the 4 losses from the paper does it refer to?

Does the CE loss from here refer to the $l_p$ loss in the paper? If not, which loss from the paper does it refer to?

This and the loss defined on line 508 make up $l_p$ in the paper (equation 2). This is the loss for training the model to produce the [IMG] tokens at the end of "caption-like" text. Both L506 and L508 are actually the same loss, since the same caption (e.g., "A picture of a dog [IMG0]...[IMG7]") is used for retrieval and generation. This is why each has a 0.5 multiplier, so that they sum to $l_p$.
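For anyone following along, here is a minimal, self-contained sketch of how two 0.5-weighted cross-entropy terms can sum to a single captioning loss. The tensor names and shapes are illustrative only and do not correspond to the actual variables in main.py:

```python
import torch
import torch.nn.functional as F

def caption_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Token-level cross-entropy over the caption; positions labelled -100 are ignored."""
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
    )

# Dummy tensors standing in for the retrieval and generation branches, which both
# see the same caption (e.g. "A picture of a dog [IMG0]...[IMG7]").
batch, seq_len, vocab = 2, 16, 50274
ret_logits = torch.randn(batch, seq_len, vocab)
gen_logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))

# Each branch gets a 0.5 multiplier so the two terms sum to l_p (equation 2).
l_p = 0.5 * caption_ce(ret_logits, labels) + 0.5 * caption_ce(gen_logits, labels)
```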

From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 so that they are ignored when calculating the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r={2,3,...,8})?

That's right, and the reason is that we force the generation of the r={2,3,...,8} tokens as the next 7 tokens whenever the model produces [IMG0] (we always need all 8 tokens for generation/retrieval, so it doesn't make sense to have a partial set of [IMG] tokens). The embeddings of the [IMG2]...[IMG8] tokens are therefore only learnt through the other losses (in particular the generation loss $l_g$), when their embeddings/hidden states are used for computing the generation/retrieval objectives; $l_p$ doesn't affect the [IMG2]...[IMG8] tokens. So the model will never produce [IMG2]...[IMG8] organically, but their representations are still helpful for feeding into the GILLMapper module for image generation.
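To make the forced-generation part concrete, here's a rough sketch of the inference-time behaviour (the identifiers are made up for illustration, not the actual ones in this repo):

```python
from typing import List

def force_all_img_tokens(generated_ids: List[int], img_token_ids: List[int]) -> List[int]:
    """If the model has just emitted the first [IMG] token, append the remaining
    [IMG] token ids so the full set of 8 is always available for retrieval/generation."""
    if generated_ids and generated_ids[-1] == img_token_ids[0]:
        generated_ids = generated_ids + img_token_ids[1:]
    return generated_ids

# Example: suppose the [IMG] tokens occupy ids 50266..50273 at the end of the vocab
# (an assumption for illustration). Once id 50266 ([IMG0]) appears, the rest are forced.
img_token_ids = list(range(50266, 50274))
print(force_all_img_tokens([10, 42, 50266], img_token_ids))
```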

Hope that makes sense!

avipartho commented 1 year ago

Thanks for your quick response. Reopening this issue for another query regarding the pipeline (didn't want to unnecessarily create a new issue).

If I am not wrong, this line makes the entire OPT embedding layer trainable. This is also evident from the param_count.txt file your scripts generate. However, according to the paper, only the [IMG] embedding matrix $E_{\text{img}}$ was supposed to be trainable. Did I miss anything here?

First few lines from param_count.txt :

| Module | Trainable | Shape | Param Count |
| --- | --- | --- | --- |
| model.logit_scale | True | () | 1 |
| model.lm.model.decoder.embed_tokens.weight | True | (50274, 4096) | 205,922,304 |

kohjingyu commented 1 year ago

You're right, they become trainable, which is why we zero out the gradients of the non-[IMG] embeddings here:

https://github.com/kohjingyu/gill/blob/53fdcf22952ba0e08fe2cf5b006948df7c5636aa/main.py#L578-L587

This is not super ideal, but I think it is overall cleaner than concatenating a trainable embedding matrix with a frozen one.
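For reference, a minimal sketch of what that gradient-zeroing step does; the helper name and the way the [IMG] row indices are passed in are assumptions on my part, and the actual logic lives at the lines linked above:

```python
import torch

def zero_non_img_grads(embed_weight: torch.nn.Parameter, img_token_ids: list) -> None:
    """Call after loss.backward() and before optimizer.step(): wipe the gradients of
    every embedding row except the [IMG] token rows, so only those rows get updated."""
    if embed_weight.grad is None:
        return
    mask = torch.zeros_like(embed_weight.grad)
    mask[img_token_ids] = 1.0  # keep gradients only for the [IMG] rows
    embed_weight.grad.mul_(mask)
```

A gradient hook registered on the embedding weight would achieve the same thing, but zeroing explicitly after backward keeps the training loop easy to follow.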

avipartho commented 1 year ago

Thanks again. Unfortunately, I missed this section of the script.

Is it also correct to say that for the $l_p$ loss, you are considering the loss for generating each token of the input text (caption), i.e. the negative log-likelihood of generating token $s_t$ conditioned on $s_1, \ldots, s_{t-1}$, where $t \in \{1, \ldots, T\}$?

kohjingyu commented 1 year ago

Yes that's right!

avipartho commented 1 year ago

In that case, I believe equation 2 is slightly misleading, as the summation there goes over $i$ from 1 to $r$. This effectively says that we are considering the loss for generating all 8 [IMG] tokens.

kohjingyu commented 1 year ago

You're absolutely right, thanks for pointing this out! We'll fix this in the paper soon. The correct scheme should be that the loss is only considered for the first [IMG0] token. The part about forcing generation of the remaining tokens during inference is still true.
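Roughly, the corrected objective would sum the negative log-likelihood over the caption tokens $s_1, \ldots, s_T$ plus a single term for producing [IMG0], rather than over all $r$ [IMG] tokens; something like:

$$l_p = -\sum_{t=1}^{T} \log p(s_t \mid s_1, \ldots, s_{t-1}) \;-\; \log p(\text{[IMG0]} \mid s_1, \ldots, s_T)$$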

avipartho commented 1 year ago

As I was trying to train my own model and run inference using the saved checkpoint, I noticed a few things; please verify (they might be helpful for other users).

kohjingyu commented 1 year ago

Thanks for sharing this!

  • The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py produces a checkpoint with input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. This requires changing either this line or this line to run inference with the produced checkpoint. For example,

You're right, and I also realized that I hadn't uploaded the script used to prune the checkpoints (keeping just the trained weights, and discarding the pretrained model weights). I just did that here: https://github.com/kohjingyu/gill/blob/main/scripts/prune_model_ckpt.py

I think this is essentially the same as the changes you probably made locally, though I haven't tested this script in a while.
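Roughly, the pruning idea looks like this (a sketch only: the file names, state-dict keys, and the assumption that the [IMG] rows sit at the end of the embedding matrix are inferred from this thread and may differ from what prune_model_ckpt.py actually does):

```python
import torch

NUM_IMG_TOKENS = 8  # [IMG0]...[IMG7]

ckpt = torch.load("ckpt_best.pth.tar", map_location="cpu")  # hypothetical path
pruned = {}
for name, weight in ckpt["state_dict"].items():
    if name.endswith("embed_tokens.weight") or name.endswith("input_embeddings.weight"):
        # (50274, 4096) -> (8, 4096): keep only the trained [IMG] rows,
        # assumed here to be the last 8 rows of the vocabulary.
        pruned[name] = weight[-NUM_IMG_TOKENS:].clone()
    elif ".lm." in name:
        continue  # frozen OPT weights: drop them, they can be reloaded from Hugging Face
    else:
        pruned[name] = weight  # logit_scale, mapping layers, GILLMapper, etc.

torch.save({"state_dict": pruned}, "pretrained_ckpt.pth.tar")  # hypothetical path
```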

avipartho commented 1 year ago

Thanks for sharing the script! Just noticed a few things -

kohjingyu commented 1 year ago

Thanks for the notes! Sorry about this, it's what happens when you don't test before you upload...

These arguments no longer exist in the current version. Did you perhaps coalesce them into num_tokens? Please verify.

That's right.

What's the use of share_ret_gen? I could not find any use of it in the models.py, validate.py, or main.py scripts.

share_ret_gen doesn't exist anymore; I think it was something used during debugging previously. I've updated the script accordingly, so hopefully it works as expected now. Thanks for your help in debugging this!

avipartho commented 1 year ago

Another small update: I could not find warmup-scheduler==0.3.2 (as listed in the requirements.txt file); the currently available version appears to be 0.3. Will it be compatible with your scripts? (I can verify that training runs with this version.)

kohjingyu commented 1 year ago

Ah, looks like it should be pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git instead. The link you provided should still work though.

avipartho commented 1 year ago

I have another question. As mentioned above, the $l_p$ loss includes the negative log-likelihood (NLL) of generating each token of the input text (caption). Did you find this helpful for the overall model performance? I am asking because, given the name and purpose of this loss, I would have assumed it was intended to consider only the NLL of generating the [IMG] tokens.

kohjingyu commented 1 year ago

I have not run this particular ablation, sorry. I would guess that it does not have a significant effect on performance on the tasks we evaluated on.