kohjingyu / fromage

🧀 Code and models for the ICML 2023 paper "Grounding Language Models to Images for Multimodal Inputs and Outputs".
https://jykoh.com/fromage
Apache License 2.0

The reproduction of FROMAGe training #22

Closed Ziyang412 closed 1 year ago

Ziyang412 commented 1 year ago

Hi! I am trying to reproduce the training of the FROMAGe model on the CC3M dataset, and the final output on the CC3M val set looks normal:

Computing similarity between torch.Size([12856, 256]) and torch.Size([12856, 256]).                                                      
 * Time 9.645 Loss 3.445 Acc@1 46.411 Acc@5 69.867 BLEU@4 0.063       

However, when I try to evaluate on the VisDial dataset (after pruning the checkpoint), I get the error below:

(screenshot of the error trace: error_fromage)

Printing the dimensions of these tensors, it seems that the saved "ret_input_embeddings.weight" has shape [1, 4096] instead of [4096].

To work around this, I changed the code at https://github.com/kohjingyu/fromage/blob/92c6d6f6ea9cea38f0b0a12bcdb0cf3915d0e774/fromage/models.py#L679

to the code below (adding a squeeze):

model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze(0).cpu().detach())

With this change, I get reasonable results on IT2T. However, when I test the T2I setting, the results are even worse than random guessing:

top-k, k=1, acc=0.00000
top-k, k=5, acc=0.00097
top-k, k=10, acc=0.00242

Could you help me with this? Thank you so much!

kohjingyu commented 1 year ago

Hi,

That's strange! It suggests that the rest of the model is loaded correctly, except for the ret_embeddings weight (since this computes the text embeddings used in retrieval). Could you help check two things?

1. If you load the pretrained model in this repo for the VisDial evals, do the numbers look reasonable?
2. If you use the saved checkpoint without pruning, does it work? You can probably do this by commenting out the following two lines:

https://github.com/kohjingyu/fromage/blob/92c6d6f6ea9cea38f0b0a12bcdb0cf3915d0e774/fromage/models.py#L678-L679
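
As a quick way to check, here is a minimal sketch (not from the repo) for inspecting what actually got saved; the checkpoint path is a placeholder:

```python
import torch

# Placeholder path; point this at the checkpoint from your training run
# (or at its pruned version).
ckpt = torch.load('checkpoint_best.pth.tar', map_location='cpu')

# List every saved key and its shape, e.g. to see under what name and
# with what shape 'ret_input_embeddings.weight' was stored.
for name, tensor in ckpt['state_dict'].items():
    print(name, tuple(tensor.shape))
```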

Ziyang412 commented 1 year ago

I found several differences between the model args in my run's log directory and the ones in the /fromage_model directory; I don't know whether this affects the weight loading for the [RET] token. If so, is there any way I can fix it?

(screenshot comparing the model args from my run with those in /fromage_model)
Ziyang412 commented 1 year ago

> Hi,
>
> That's strange! It suggests that the rest of the model is loaded correctly, except for the ret_embeddings weight (since this computes the text embeddings used in retrieval). Could you help check two things?
>
> 1. If you load the pretrained model in this repo for the VisDial evals, do the numbers look reasonable?
> 2. If you use the saved checkpoint without pruning, does it work? You can probably do this by commenting out the following two lines:
>
> https://github.com/kohjingyu/fromage/blob/92c6d6f6ea9cea38f0b0a12bcdb0cf3915d0e774/fromage/models.py#L678-L679

Thank you for the reply.

For 1, yes, I can reproduce reasonable results using the pretrained model.

For 2, no; I tried, but I get the same results as with the pruned checkpoint.

Hope to get this fixed soon; thank you in advance!

Ziyang412 commented 1 year ago

BTW, this is the training script I used, in case it helps. The only thing I changed in main.py is the GPU count (https://github.com/kohjingyu/fromage/blob/92c6d6f6ea9cea38f0b0a12bcdb0cf3915d0e774/main.py#L190C1-L190C1): I set ngpus_per_node to 4.

    export NCCL_P2P_DISABLE=1
    randport=$(shuf -i8000-9999 -n1)  # Generate a random port number
    python -u main.py \
        --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
        --multiprocessing-distributed --world-size 1 --rank 0 \
        --dataset=cc3m --val-dataset=cc3m \
        --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \
        --exp_name='fromage_train_exp_6.29_bs120_lr2_valbs80' --image-dir='/data/cc3m_dl/conceptual_caption/' --log-base-dir='runs/' \
        --batch-size=120 --val-batch-size=80 --learning-rate=0.0002 --precision='bf16' --print-freq=100

kohjingyu commented 1 year ago

Thanks for sharing that! I managed to reproduce the problem. It happens because, when training with DDP, the keys in the state_dict carry a `module.` prefix, so the weights were not being restored at all, since the names don't match (and we set strict=False in order to load the pretrained OPT/CLIP weights). I've just pushed a commit that updates the fromage/prune_model_ckpt.py script to remove the prefix:

https://github.com/kohjingyu/fromage/blob/51fb06acf72f7abacd6da49cbc8c09a56826fbd0/fromage/prune_model_ckpt.py#L21-L22
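
For anyone hitting the same issue with an already-saved checkpoint, the fix amounts to something like the sketch below (illustrative only, not the exact repo code; the file names are placeholders): strip the DDP `module.` prefix from the keys before loading or re-pruning.

```python
import torch
from collections import OrderedDict

# Illustrative sketch; file names are placeholders.
ckpt = torch.load('checkpoint_best.pth.tar', map_location='cpu')

stripped = OrderedDict()
for k, v in ckpt['state_dict'].items():
    # DDP wraps the model, so saved keys look like "module.model.xxx";
    # drop the leading "module." so they match the un-wrapped model's names.
    stripped[k[len('module.'):] if k.startswith('module.') else k] = v
ckpt['state_dict'] = stripped

torch.save(ckpt, 'pretrained_ckpt.pth.tar')
```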

If you rerun the script and the evals, I think it should work as expected now, but please let me know if it doesn't!

Ziyang412 commented 1 year ago

Yeah, it works! Thank you so much!