avipartho closed this issue 1 year ago.
Thanks for your kind words!
Does the CE loss from here refer to the $l_p$ loss in the paper? If not, which of the 4 losses from the paper does it refer to?
This and the loss defined on line 508 make up $l_p$ in the paper (equation 2). This is the loss for training the model to produce the [IMG] tokens at the end of "caption-like" text. The losses on L506 and L508 are actually the same loss, because the same caption (e.g., "A picture of a dog [IMG0]...[IMG7]") is used for both retrieval and generation. This is why each has a 0.5 multiplier, so that together they sum to $l_p$.
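To make the 0.5 weighting concrete, here is a minimal pure-Python sketch (not the repo's code) of a next-token cross-entropy that skips positions labelled -100. This follows the same convention as the default `ignore_index` of PyTorch's `CrossEntropyLoss`. Because the retrieval and generation branches score the same caption, the two halves add up to a single caption loss. The names `masked_token_nll`, `ce_ret` and `ce_gen` are hypothetical.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100  # same value as PyTorch's CrossEntropyLoss default ignore_index


def masked_token_nll(log_probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean negative log-likelihood over positions whose label is not IGNORE_INDEX.

    log_probs[t][v] is the log-probability of vocabulary item v at position t, and
    labels[t] is the (already shifted) target id at that position.
    """
    losses = [-log_probs[t][y] for t, y in enumerate(labels) if y != IGNORE_INDEX]
    if not losses:
        raise ValueError("every position is ignored")
    return sum(losses) / len(losses)


# Toy 3-item vocabulary. The same caption feeds both the retrieval and generation branches.
log_probs = [[math.log(p) for p in row]
             for row in ([0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4])]
labels = [0, 1, IGNORE_INDEX]  # last position is masked out

ce_ret = masked_token_nll(log_probs, labels)  # hypothetical retrieval-branch CE term
ce_gen = masked_token_nll(log_probs, labels)  # hypothetical generation-branch CE term
l_p = 0.5 * ce_ret + 0.5 * ce_gen              # two identical halves sum to one caption loss
print(round(l_p, 4), round(ce_ret, 4))         # prints: 0.2899 0.2899
```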
From these lines (line1, line2, line3), it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 so that they are ignored when computing the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r = 1, ..., 7)?
That's right. The reason is that we force the generation of the [IMG1]...[IMG7] tokens as the next 7 tokens whenever the model produces [IMG0]. We always need all 8 tokens for generation/retrieval, so it doesn't make sense to have a partial set of the [IMG] tokens. The embeddings of [IMG1]...[IMG7] are therefore only learnt through the other losses (in particular the generation loss $l_g$), when their embeddings/hidden states are used for computing the generation/retrieval objectives. $l_p$ doesn't affect them. As a result, the model will never produce [IMG1]...[IMG7] organically, but their representations are still helpful for feeding into the GILLMapper module for image generation.
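The inference-time behaviour described above can be sketched as a greedy decoding loop. This is a toy illustration under stated assumptions, not GILL's implementation. `next_token` stands in for one language-model step, `IMG_TOKEN_IDS` are made-up ids, and the loop simply stops after the forced [IMG] block.

```python
from typing import Callable, List

# Hypothetical ids for [IMG0] ... [IMG7]. The real ids come from the tokenizer.
IMG_TOKEN_IDS = list(range(100, 108))


def decode_with_forced_img_tokens(
    next_token: Callable[[List[int]], int], prompt: List[int], max_new_tokens: int
) -> List[int]:
    """Greedy-decoding sketch: once [IMG0] is produced, append [IMG1]..[IMG7] directly."""
    out = list(prompt)
    generated = 0
    while generated < max_new_tokens:
        tok = next_token(out)  # stand-in for one forward pass + argmax
        if tok == IMG_TOKEN_IDS[0]:
            # Emit the full block of 8 [IMG] tokens instead of predicting [IMG1]..[IMG7].
            out.extend(IMG_TOKEN_IDS)
            break  # assumption of this sketch: stop after the [IMG] block
        out.append(tok)
        generated += 1
    return out


# Toy "model" that emits two text tokens and then [IMG0].
script = iter([5, 6, 100])
result = decode_with_forced_img_tokens(lambda ctx: next(script), prompt=[1, 2], max_new_tokens=10)
print(result)  # [1, 2, 5, 6, 100, 101, 102, 103, 104, 105, 106, 107]
```

Because [IMG1]...[IMG7] are always inserted rather than predicted, no next-token loss is needed on them. This matches their labels being -100.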
Hope that makes sense!
Thanks for your quick response. I'm reopening this issue for another question about the pipeline (I didn't want to create a new issue unnecessarily).
If I am not wrong, this line makes the entire OPT embedding layer trainable. This is also evident from the param_count.txt file that your scripts generate. However, according to the paper, only the [IMG] embedding matrix $E_{img}$ was supposed to be trainable. Did I miss anything here?
First few lines from param_count.txt:
| model.logit_scale | True | () | 1 |
| model.lm.model.decoder.embed_tokens.weight | True | (50274, 4096) | 205,922,304 |
You're right, they become trainable, which is why we zero out the gradients of the non-[IMG] embeddings here:
https://github.com/kohjingyu/gill/blob/53fdcf22952ba0e08fe2cf5b006948df7c5636aa/main.py#L578-L587
This is not super ideal, but I think it is overall cleaner than concatenating a trainable embedding matrix with a frozen one.
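Here is a minimal sketch of the gradient-masking idea. It is written in plain Python rather than PyTorch: nested lists stand in for tensors, and `zero_non_img_grads` / `sgd_step` are hypothetical helpers, not functions from main.py. The sketch assumes the [IMG] rows are the last `num_img_tokens` rows of the embedding matrix. In PyTorch, the equivalent would be to zero those rows of the embedding weight's `.grad` after `loss.backward()` and before `optimizer.step()`.

```python
from typing import List

Matrix = List[List[float]]


def zero_non_img_grads(grad: Matrix, num_img_tokens: int) -> None:
    """In place: keep gradients only for the last `num_img_tokens` rows (the [IMG] embeddings)."""
    for row in grad[: len(grad) - num_img_tokens]:  # slice holds references to the same rows
        for j in range(len(row)):
            row[j] = 0.0


def sgd_step(weight: Matrix, grad: Matrix, lr: float) -> None:
    """Plain SGD update, applied in place."""
    for w_row, g_row in zip(weight, grad):
        for j, g in enumerate(g_row):
            w_row[j] -= lr * g


# Toy embedding matrix: 2 "pretrained" rows followed by 2 [IMG] rows.
weight = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
grad = [[1.0, 1.0] for _ in weight]
zero_non_img_grads(grad, num_img_tokens=2)
sgd_step(weight, grad, lr=0.5)
print(weight)  # [[1.0, 1.0], [2.0, 2.0], [2.5, 2.5], [3.5, 3.5]]
```

It is worth checking one caveat with this approach. Optimizers that apply weight decay to the embedding parameter (for example `torch.optim.AdamW`) can still change rows whose gradient is zero. The non-[IMG] rows are therefore only guaranteed to stay fixed if weight decay is not applied to that parameter.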
Thanks again. Unfortunately, I missed this section of the script.
Is it also correct to say that for the $l_p$ loss, you are considering the loss for generating each token of the input text (caption)? That is, the negative log-likelihood of generating token $s_t$ conditioned on $s_1, \ldots, s_{t-1}$, where $t = 1, \ldots, T$?
Yes that's right!
In that case, I believe equation 2 is slightly misleading, because the summation there goes over $i$ from $1$ to $r$. This effectively says that we are considering the loss for generating all 8 [IMG] tokens.
You're absolutely right, thanks for pointing this out! We'll fix this in the paper soon. The correct scheme is that the loss is only computed for the first [IMG0] token. The part about forcing generation of the remaining tokens during inference still holds.
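Putting the last few replies together, the loss that is actually computed can be written roughly as follows. This is a reading of this thread, not the paper's corrected equation. The notation is assumed: $x$ is the caption, $s_1, \ldots, s_T$ are its tokens, and $p_\theta$ is the language model. In code, the cross-entropy is averaged over the non-ignored positions rather than summed (PyTorch's default `reduction='mean'`). [IMG1]...[IMG7] contribute no term.

```latex
l_p(x) \;=\; -\sum_{t=1}^{T} \log p_\theta\!\left(s_t \mid s_1, \ldots, s_{t-1}\right)
\;-\; \log p_\theta\!\left(\texttt{[IMG0]} \mid s_1, \ldots, s_T\right)
```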
While training my own model and running inference with the saved checkpoint, I noticed a few things. Please verify them (this might be helpful for other users).
- GILLArgs() only has local variables and no attributes, so none of them get saved in the model_args.json file unless they are explicitly set after instantiation. One such attribute is text_emb_layers. Turning all local variables into class attributes can solve this.
- The main.py script also saves the scheduler state. This probably ends up saving essentially the entire model and therefore results in a large checkpoint.
- The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. To run inference with the produced checkpoint, you need to change either this line or this line. For example:
  `img_token_embeddings = state_dict['model.input_embeddings.weight'].cpu().detach()[-model_kwargs['num_tokens']:, :]`
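Here is a plain-Python illustration of the first point. It assumes the arguments are written to model_args.json with something like `json.dump(vars(args), f)`; the thread doesn't show the saving code, so this is an assumption. Whether the defaults are locals or class-level annotations, the symptom is the same: `vars(args)` only sees attributes stored on the instance. `PlainArgs` and `DataclassArgs` are hypothetical stand-ins for `GILLArgs`. Turning the class into a `@dataclass` is one way to make every default an instance attribute.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import List


class PlainArgs:
    # Class-level defaults: visible via attribute lookup, but NOT stored on the instance.
    num_tokens: int = 8
    text_emb_layers: List[int] = [-1]


@dataclass
class DataclassArgs:
    # @dataclass generates an __init__ that stores every field on the instance.
    num_tokens: int = 8
    text_emb_layers: List[int] = field(default_factory=lambda: [-1])  # mutable default


plain = PlainArgs()
plain.num_tokens = 4  # only explicitly assigned values land in the instance __dict__
print(json.dumps(vars(plain)))  # {"num_tokens": 4}

dc = DataclassArgs(num_tokens=4)
print(json.dumps(asdict(dc)))  # {"num_tokens": 4, "text_emb_layers": [-1]}
```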
Thanks for sharing this!
- The pretrained checkpoint provided in this codebase has a shape of (8, 4096) for input_embeddings.weight, whereas running main.py will produce a checkpoint with input_embeddings.weight of shape (50274, 4096). It looks like the provided checkpoint contains only the trainable [IMG] token embeddings. To run inference with the produced checkpoint, you need to change either this line or this line.
You're right, and I also realized that I hadn't uploaded the script used to prune the checkpoints (keeping just the trained weights, and discarding the pretrained model weights). I just did that here: https://github.com/kohjingyu/gill/blob/main/scripts/prune_model_ckpt.py
I think this is essentially the same as the changes you probably made locally, though I haven't tested this script in a while.
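The pruning step can be sketched as follows. The sketch assumes, as in the slicing line quoted earlier, that the [IMG] embeddings are the last `num_tokens` rows of `model.input_embeddings.weight`. It is not the contents of scripts/prune_model_ckpt.py. The state dict is modelled as nested lists instead of tensors, and `prune_state_dict`, `trainable_keys` and the `model.mapper.weight` key are hypothetical.

```python
from typing import Dict, List, Sequence

Matrix = List[List[float]]


def prune_state_dict(
    state_dict: Dict[str, Matrix],
    trainable_keys: Sequence[str],
    embedding_key: str,
    num_tokens: int,
) -> Dict[str, Matrix]:
    """Keep only trained weights and shrink the full embedding matrix to its [IMG] rows."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")  # [-0:] would keep every row
    pruned = {k: v for k, v in state_dict.items() if k in trainable_keys}
    if embedding_key in state_dict:
        # Mirrors `weight[-num_tokens:, :]`: the [IMG] rows are assumed to come last.
        pruned[embedding_key] = state_dict[embedding_key][-num_tokens:]
    return pruned


state = {
    "model.lm.weight": [[9.0]],  # frozen pretrained weight: dropped
    "model.input_embeddings.weight": [[0.0], [1.0], [2.0], [3.0], [4.0]],
    "model.mapper.weight": [[5.0]],  # trained weight: kept
}
pruned = prune_state_dict(state, ["model.mapper.weight"], "model.input_embeddings.weight", 2)
print(pruned)  # {'model.mapper.weight': [[5.0]], 'model.input_embeddings.weight': [[3.0], [4.0]]}
```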
Thanks for sharing the script! I just noticed a few things:
- These arguments no longer exist in the current version. Could it be that you coalesced them into num_tokens? Please verify.
- … models.py)
- … `python scripts/prune_model_ckpt.py runs/gill_exp`, given the location of the script.
- What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.

Thanks for the notes! Sorry about this, it's what happens when you don't test before you upload...
These arguments no longer exist in the current version. Could it be that you coalesced them into num_tokens? Please verify.
That's right.
What's the use of share_ret_gen? I could not find any use of this in the models.py, validate.py or main.py script.
share_ret_gen doesn't exist anymore; I think it was something used for debugging previously. I've updated the script accordingly, so hopefully it works as expected now. Thanks for your help in debugging this!
Another small update: I could not find warmup-scheduler==0.3.2 (as listed in the requirements.txt file); the currently available version seems to be 0.3. Will it be compatible with your scripts? (I can confirm that training runs with this version.)
Ah, looks like it should be `pip install git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git` instead. The link you provided should still work, though.
I have another question. As mentioned above, the $l_p$ loss includes the negative log-likelihood (NLL) of generating each token of the input text (caption). Did you find this helpful for overall model performance? I am asking because, from the name and purpose of this loss, I would have assumed it was intended to consider only the NLL of generating the [IMG] tokens.
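For anyone who wants to run that ablation, the two variants differ only in which positions keep a real label. The helper below is a hypothetical sketch, not code from the repo. It assumes the input is the caption followed immediately by the 8 [IMG] tokens (no padding). It also assumes the model shifts labels internally for next-token prediction, as Hugging Face causal LMs such as `OPTForCausalLM` do.

```python
from typing import List

IGNORE_INDEX = -100  # positions with this label are skipped by the cross-entropy


def lp_labels(input_ids: List[int], num_caption_tokens: int, img_only: bool) -> List[int]:
    """Build labels for `caption tokens + [IMG0..IMG7]` (hypothetical helper).

    img_only=False: supervise every caption token plus [IMG0], as described in this thread.
    img_only=True:  supervise [IMG0] only, i.e. the ablation asked about above.
    """
    labels = []
    for pos, tok in enumerate(input_ids):
        if pos < num_caption_tokens:
            labels.append(IGNORE_INDEX if img_only else tok)
        elif pos == num_caption_tokens:
            labels.append(tok)  # [IMG0] is always supervised
        else:
            labels.append(IGNORE_INDEX)  # [IMG1]..[IMG7] never contribute to l_p
    return labels


caption = [11, 12, 13]              # made-up caption token ids
img_tokens = list(range(100, 108))  # made-up ids for [IMG0]..[IMG7]
ids = caption + img_tokens
print(lp_labels(ids, len(caption), img_only=False))
# [11, 12, 13, 100, -100, -100, -100, -100, -100, -100, -100]
print(lp_labels(ids, len(caption), img_only=True))
# [-100, -100, -100, 100, -100, -100, -100, -100, -100, -100, -100]
```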
I have not run this particular ablation, sorry. I would guess that it does not have a significant effect on performance on the tasks we evaluated on.
Hi,
I read your paper and it is great work! Thanks for sharing your codebase with the community. While going through your code, I came across a few places where I would greatly appreciate your explanations/suggestions. Here are my questions:
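From these lines, it looks like all tokens that are not part of the caption text or [IMG0] have been set to -100 so that they are ignored when computing the loss. Is my understanding correct? If it is, how are we learning embeddings for the other [IMG{r}] tokens (r = 1, ..., 7)?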