LijieFan / LaCLIP

[NeurIPS 2023] Text data, code and pre-trained models for paper "Improving CLIP Training with Language Rewrites"
BSD 2-Clause "Simplified" License

Unable to reproduce CLIP Baselines on CC3M with original captions #6

Closed: HaniItani closed this issue 10 months ago

HaniItani commented 11 months ago

Hello,

Thank you for sharing your great work!

I'm having issues reproducing your CLIP baselines on CC3M with the original captions. I suspect you are using the SLIP codebase to train your models on CC3M and CC12M, and I'm doing the same. I also noticed that the difference between your CLIP implementation and theirs is the qk_norm in the attention blocks of the vision transformer.

I've implemented your evaluation pipeline and, using your provided checkpoint, managed to reproduce the linear probing numbers you report for some of the datasets. The image below shows that the gap between my reproduction attempt and your results is substantial. Note that the best ImageNet zero-shot accuracy I got in my reproduction is 15.2, which is comparable to your reported 15.8.

[image: comparison of the reproduced results with the reported ones]

Can you please advise? Thanks!

HobbitLong commented 11 months ago

Hi, @HaniItani,

Yes, we modified the SLIP repo. However, its all_gather operator does not back-propagate gradients, so you need to switch to one that does, e.g. torch.distributed.nn.all_gather.
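To illustrate why the gather operator matters, here is a minimal single-process simulation, assuming a SLIP-style symmetric CLIP loss in which each worker scores its local embeddings against the gathered global batch. It does not call torch.distributed at all: `detach()` stands in for a gather whose outputs carry no gradient (as with `torch.distributed.all_gather`), while keeping the gathered copies in the autograd graph stands in for a gradient-aware gather such as `torch.distributed.nn.all_gather`. The worker count, batch sizes, logit scale, and helper names are hypothetical, and this is a sketch, not the LaCLIP loss.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes for the simulation.
WORLD_SIZE = 2
LOCAL_BATCH = 3
DIM = 8
LOGIT_SCALE = 10.0


def clip_loss(img_local, txt_local, img_all, txt_all, rank):
    """Symmetric InfoNCE for one worker: local rows vs. the gathered batch."""
    n = img_local.shape[0]
    # Positives sit at this worker's offset inside the gathered batch.
    labels = torch.arange(n) + rank * n
    logits_per_image = LOGIT_SCALE * img_local @ txt_all.t()
    logits_per_text = LOGIT_SCALE * txt_local @ img_all.t()
    return (F.cross_entropy(logits_per_image, labels)
            + F.cross_entropy(logits_per_text, labels)) / 2


def simulated_grads(img_base, txt_base, gather_with_grad):
    """Gradients w.r.t. every worker's embeddings, averaged over workers like DDP."""
    img = [t.clone().requires_grad_() for t in img_base]
    txt = [t.clone().requires_grad_() for t in txt_base]
    losses = []
    for rank in range(WORLD_SIZE):
        if gather_with_grad:
            # Gathered copies stay in the graph (gradient-aware gather).
            img_all, txt_all = torch.cat(img), torch.cat(txt)
        else:
            # Gathered copies are plain data (gather without gradient).
            img_all = torch.cat([t.detach() for t in img])
            txt_all = torch.cat([t.detach() for t in txt])
        losses.append(clip_loss(img[rank], txt[rank], img_all, txt_all, rank))
    torch.stack(losses).mean().backward()
    return torch.cat([t.grad for t in img]), torch.cat([t.grad for t in txt])


def full_batch_grads(img_base, txt_base):
    """Reference: a single device holding the whole global batch."""
    img = torch.cat(img_base).clone().requires_grad_()
    txt = torch.cat(txt_base).clone().requires_grad_()
    clip_loss(img, txt, img, txt, rank=0).backward()
    return img.grad, txt.grad


torch.manual_seed(0)
img_base = [F.normalize(torch.randn(LOCAL_BATCH, DIM), dim=-1) for _ in range(WORLD_SIZE)]
txt_base = [F.normalize(torch.randn(LOCAL_BATCH, DIM), dim=-1) for _ in range(WORLD_SIZE)]

ref_img, ref_txt = full_batch_grads(img_base, txt_base)
with_img, with_txt = simulated_grads(img_base, txt_base, gather_with_grad=True)
without_img, without_txt = simulated_grads(img_base, txt_base, gather_with_grad=False)

print("gather with grad matches full batch:",
      torch.allclose(with_img, ref_img, atol=1e-6) and torch.allclose(with_txt, ref_txt, atol=1e-6))
print("gather without grad matches full batch:",
      torch.allclose(without_img, ref_img, atol=1e-6))
```

Averaging the per-worker losses mirrors DDP's gradient averaging, so in this setup the gradient-aware gather reproduces the single large-batch gradient, while the detached version drops every gradient term that flows through the gathered copies.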

Yes, we used qk_norm. Our initial trials showed that it slightly improves training stability, so we kept it. Performance-wise, I don't think it causes noticeable differences.
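For readers unfamiliar with qk_norm: it usually means normalizing each head's queries and keys (commonly with LayerNorm over the head dimension) before the scaled dot product; the usual motivation is to keep attention logits from growing too large. Below is a minimal sketch under that assumption. The class name and defaults are hypothetical, and LaCLIP's exact placement and choice of normalization may differ; recent timm versions expose a similar `qk_norm` option on their ViT attention block.

```python
import torch
import torch.nn as nn


class QKNormAttention(nn.Module):
    """Multi-head self-attention with LayerNorm on per-head queries and keys."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError("dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        # qk_norm: normalize q and k over each head's channels.
        self.q_norm = nn.LayerNorm(self.head_dim)
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        b, n, c = x.shape
        # (b, n, 3*c) -> (3, b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        attn = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1)
        # (b, heads, n, head_dim) -> (b, n, c)
        out = (attn @ v).transpose(1, 2).reshape(b, n, c)
        return self.proj(out)


torch.manual_seed(0)
layer = QKNormAttention(dim=64, num_heads=4)
tokens = torch.randn(2, 10, 64)  # (batch, tokens, channels)
out = layer(tokens)
print(out.shape)  # torch.Size([2, 10, 64])
```

Without qk_norm, `q_norm` and `k_norm` would simply be identities; everything else in the block is a standard ViT attention layer.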

For the linear probing, we followed the setup in, e.g., section A.6 of this paper.

HaniItani commented 11 months ago

Hi @HobbitLong,

Thank you very much for your response. Can you please elaborate on this fix? Is it as easy as replacing this line here with torch.distributed.nn.all_gather? I also see that the SLIP codebase has a gather-with-grad function here. Is it the same?
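As a rough illustration of the kind of replacement being asked about, a SLIP-style batch gather could route through `torch.distributed.nn.all_gather`, which returns a tuple of tensors that stay in the autograd graph. This is a hypothetical helper (`all_gather_batch_with_autograd` is not a name from either repo) that assumes the same list-in, list-out interface as SLIP's batch gather; it is a sketch, not the LaCLIP loss implementation.

```python
import torch
import torch.distributed as dist
import torch.distributed.nn  # provides the differentiable torch.distributed.nn.all_gather


def all_gather_batch_with_autograd(tensors):
    """Gather each tensor from all ranks and concatenate along dim 0.

    Unlike torch.distributed.all_gather, torch.distributed.nn.all_gather is
    differentiable: in backward, each rank receives the sum of the gradients
    computed for its slice on every rank.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return tensors  # single process: nothing to gather
    if dist.get_world_size() == 1:
        return tensors
    return [torch.cat(torch.distributed.nn.all_gather(t), dim=0) for t in tensors]


# Single-process usage (no process group initialized): tensors pass through.
x = torch.randn(4, 8, requires_grad=True)
(gathered,) = all_gather_batch_with_autograd([x])
gathered.sum().backward()
```

Because the backward pass sums gradients across ranks and DDP then averages parameter gradients, the effective objective is the mean of the per-rank losses; whether that matches the intended loss scaling depends on how the loss itself is normalized.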

Your reply is very much appreciated.

LijieFan commented 10 months ago

Hi @HaniItani, thanks again for your interest. Please refer to our loss implementation.