kohjingyu / fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".
https://jykoh.com/fromage
Apache License 2.0

How are the inputs arranged for in-context retrieval evaluation? #17

Closed ys-zong closed 1 year ago

ys-zong commented 1 year ago

Hi, thanks for the nice work and code! I wonder how the input tokens are arranged for the in-context retrieval evaluation. E.g., in Table 1 ("5 captions, 4 images"), is the input to the model arranged like ([IMG][Caption])x4+[Caption], or [IMG]x4+[Caption]x5, or something else? I guess it's the former? Also, how are the in-context prompts selected: are they chosen randomly, or are the pairs whose semantic features are most similar to the query caption selected? It would be great if you could provide a code snippet for this. Many thanks!

kohjingyu commented 1 year ago

Hi, thanks for your interest! For Table 1, the input is: ([Caption][IMG])x4+[Caption]+[RET]. We basically want to format the in-context examples the same way as the retrieval query, so we use [Caption][IMG] throughout.
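
To make the ordering concrete, here's a minimal sketch of how such an interleaved prompt could be assembled (the file names and captions are placeholders, and the actual retrieval call is in the notebook linked below):

```python
from PIL import Image

# Four in-context (caption, image) pairs plus the query caption,
# interleaved as ([Caption][IMG]) x 4 + [Caption]; the [RET] token is
# appended to the final caption for retrieval.
example_pairs = [
    ("caption 1", "img1.png"),
    ("caption 2", "img2.png"),
    ("caption 3", "img3.png"),
    ("caption 4", "img4.png"),
]
query_caption = "caption 5"

prompt = []
for caption, image_path in example_pairs:
    prompt.append(caption)                 # [Caption]
    prompt.append(Image.open(image_path))  # [IMG]
prompt.append(query_caption)               # final [Caption] (+ [RET] for retrieval)

# `prompt` is then passed to the model's retrieval call; see the linked
# notebook for the exact API.
```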

I've uploaded the code for reproducing Table 1 here: https://github.com/kohjingyu/fromage/blob/main/evals/VIST_Contextual_Image_Retrieval.ipynb (there's also a .py script in the same folder, which does the same thing).

Note that there was a bug in a previous version of the code, so the results this script produces (R@1 of 18.2) are actually better than the ones in the current version of our paper (R@1 of 15.6), and this script reflects the correct way to call the model for retrieval. We'll update the paper shortly to reflect this.

Hope that helps!

ys-zong commented 1 year ago

Got it. Thank you very much!

ys-zong commented 1 year ago

Hi, just a quick follow-up question. The weights of ret_input_embeddings in this line don't seem to be in model.state_dict() when saving the model. I wonder if I should manually extract the ret_input_embeddings from the embedding layer using the [RET] token index?

kohjingyu commented 1 year ago

Ah yes, sorry about that. We pruned the checkpoint to remove the pretrained weights/embeddings so that it was small enough to upload to GitHub :)

If you are using a model saved with L379 in main.py, you should be able to load the checkpoint by removing these two lines:

https://github.com/kohjingyu/fromage/blob/7d735cd334a21e307929cdb36d845cd00599f54a/fromage/models.py#L678-L679

Simply put, if you did not prune the checkpoint, the torch.load call at

https://github.com/kohjingyu/fromage/blob/7d735cd334a21e307929cdb36d845cd00599f54a/fromage/models.py#L676

should be sufficient to load it.
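
In case it's useful, here's a rough sketch of what that loading path looks like for a full (unpruned) checkpoint; the 'state_dict' key and the 'module.' prefix handling are assumptions about how main.py saves the model, so adjust as needed:

```python
import torch


def load_unpruned_checkpoint(model, ckpt_path):
    """Load a full (unpruned) checkpoint saved by main.py into a FROMAGe model."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Handle both torch.save({'state_dict': ...}) and a bare state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip a DistributedDataParallel "module." prefix if present.
    state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}
    # The full embedding matrix is already in the checkpoint, so the pruned
    # ret_input_embeddings restore step (the two lines linked above) isn't needed.
    model.load_state_dict(state_dict, strict=False)
    return model
```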

> I wonder if I should manually extract the ret_input_embeddings from the embedding layer using the [RET] token index?

You can also do this if you would like to prune the checkpoint. I've uploaded the script we used for this here (although it's untested). I think it would probably be easier if you just commented out the above lines for a custom checkpoint, though.
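
If you do go the manual route, a pruning sketch along these lines should work; it assumes the [RET] token was added to the tokenizer and that the language model's input embeddings are reachable via get_input_embeddings(), so mirror what the actual script does where it differs:

```python
import torch


def extract_ret_embedding(model, tokenizer, ret_token="[RET]"):
    """Pull the [RET] row out of the input embedding matrix (hypothetical sketch)."""
    ret_token_idx = tokenizer.convert_tokens_to_ids(ret_token)
    input_embeddings = model.get_input_embeddings().weight  # (vocab_size, hidden_dim)
    return input_embeddings[ret_token_idx].detach().clone()


# A pruned checkpoint can then store just this row as ret_input_embeddings
# alongside the trainable weights, instead of the full embedding matrix, e.g.:
#   torch.save({"state_dict": pruned_state_dict,
#               "ret_input_embeddings": extract_ret_embedding(model, tokenizer)},
#              "pruned_ckpt.pth.tar")
```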

Hope that helps!

ys-zong commented 1 year ago

Thank you for the quick reply!