ryoungj / optdom

[ICLR'22] Self-supervised learning optimally robust representations for domain shift.
MIT License

About CLIP S results #1

Closed silvia1993 closed 2 years ago

silvia1993 commented 2 years ago

Hello @ryoungj,

first of all, thank you for sharing your code! This is an amazing work!

I read your paper and would like to replicate the results reported in Table 1, in particular "CLIP S" (4th row). If I understood correctly, these numbers are obtained using the pre-trained CLIP model (ResNet-50) as the feature extractor: an MLP is then trained with a supervised contrastive loss on the source features extracted by this model.

So my questions are:

  1. Did I understand the meaning of CLIP S correctly?
  2. Why did you use a supervised contrastive loss instead of a standard cross-entropy loss?
  3. How can I replicate these numbers using the code that you shared?

Thank you.

ryoungj commented 2 years ago

Hi, thanks for reading the paper and raising the questions!

  1. Yes, that's correct.
  2. We actually tried both the supervised contrastive loss and the standard cross-entropy loss, and the paper reports the numbers for the latter (cross-entropy), which we found worked slightly better. Also, to clarify, the MLP is trained with an additional bottleneck loss described in our paper.
  3. To reproduce the number, please refer to https://github.com/ryoungj/optdom#finetune--evaluate-clip-on-domainbed where we also provide a bash script.

silvia1993 commented 2 years ago

Hi @ryoungj,

Thank you very much for the quick reply!

I'm still a bit confused about the differences between the rows CLIP S, CLIP S + Base, and CLIP S + CAD.

CLIP S + Base -> the pre-trained CLIP model is used to extract features from the source domain, and these features are used to train an MLP without any bottleneck.

And what about CLIP S?

Thank you again!

ryoungj commented 2 years ago

CLIP S denotes the pretrained CLIP model without any further training.

Let me know if you have any more questions!

silvia1993 commented 2 years ago

I have just one more question:

I noticed that you fit an SVM on the features extracted by the MLP (placed on top of the pretrained CLIP). The MLP is trained with the cross-entropy loss together with the classifier_head. So why did you use the SVM rather than the classifier_head to test on the target?

ryoungj commented 2 years ago

We chose to refit the classifier on top of frozen extracted features instead of the jointly trained classifier head because we would like to separate the learning and evaluation of representations. Thus, we would like to purely evaluate the learned representations by doing linear probing on top of them from scratch.

For the choice of linear classifier, you could use SVM, logistic regression, etc, but we found that SVM worked slightly better (for all setups).
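The linear-probing evaluation described above can be sketched roughly as follows. This is only a minimal illustration, not the repository's actual evaluation code: the helper names (`linear_probe_accuracy`, `make_synthetic_domain`) are hypothetical, the features are synthetic Gaussians standing in for frozen MLP outputs, and the SVM settings (`C=1.0`, `max_iter=10000`) are assumptions rather than the swept values.

```python
import numpy as np
from sklearn.svm import LinearSVC


def linear_probe_accuracy(src_feats, src_labels, tgt_feats, tgt_labels, seed=0):
    """Fit a fresh linear SVM on frozen source features and score it on target features."""
    # The jointly trained classifier head is deliberately not reused here.
    clf = LinearSVC(C=1.0, max_iter=10000, random_state=seed)
    clf.fit(src_feats, src_labels)
    return clf.score(tgt_feats, tgt_labels)


def make_synthetic_domain(rng, n_per_class, dim, shift):
    """Two Gaussian classes standing in for the frozen features of one domain."""
    pos = rng.normal(loc=2.0 + shift, scale=1.0, size=(n_per_class, dim))
    neg = rng.normal(loc=-2.0 + shift, scale=1.0, size=(n_per_class, dim))
    feats = np.concatenate([pos, neg])
    labels = np.concatenate(
        [np.ones(n_per_class, dtype=int), np.zeros(n_per_class, dtype=int)]
    )
    return feats, labels


rng = np.random.default_rng(0)
# Hypothetical source and (slightly shifted) target features.
src_feats, src_labels = make_synthetic_domain(rng, 200, 16, shift=0.0)
tgt_feats, tgt_labels = make_synthetic_domain(rng, 200, 16, shift=0.5)

acc = linear_probe_accuracy(src_feats, src_labels, tgt_feats, tgt_labels)
print(f"target accuracy: {acc:.3f}")
```

Swapping `LinearSVC` for `sklearn.linear_model.LogisticRegression` in the helper would give the logistic-regression variant of the same probe.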

silvia1993 commented 2 years ago

Ok clear, many thanks for all the clarifications!

silvia1993 commented 2 years ago

Hi @ryoungj,

I ran the "SupCLIPBottleneckBase" experiments to replicate the results reported in the paper. For PACS and DomainNet I matched them, but on OfficeHome I obtained a much higher result. Do you have any idea what could cause this?

| | PACS | DomainNet | OfficeHome |
| -- | -- | -- | -- |
| Baseline (Paper) | 91.2 ± 0.3 | 46.8 ± 0.2 | 70.6 ± 0.1 |
| Baseline Replication | 91.4 ± 0.2 | 46.4 ± 0.3 | **73.8 ± 0.3** |

ryoungj commented 2 years ago

Hi,

On DomainBed, we did see the results could be very sensitive to hyperparameters and random seeds, which is also why we swept over hyperparameters and averaged results over 5 random seeds (see the script run_sweep_clip.sh).

The difference in the results that you obtained could be due to variance caused by, e.g., running with fewer random seeds, or even hardware differences. Could you provide some details about how you reproduced the results?
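As a rough illustration of how per-seed accuracies turn into a "mean ± error" entry, here is a minimal sketch. It assumes the ± denotes the standard error of the mean over seeds, which is not confirmed in this thread; the function name `mean_and_stderr` is hypothetical and the accuracies are placeholder values, not real results.

```python
import math
import statistics


def mean_and_stderr(accuracies):
    """Return (mean, standard error of the mean) for per-seed accuracies."""
    mean = statistics.mean(accuracies)
    if len(accuracies) < 2:
        # A single seed gives no estimate of the spread.
        return mean, 0.0
    stderr = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, stderr


# Placeholder accuracies for 5 hypothetical seeds (not real numbers).
per_seed_acc = [70.0, 71.0, 72.0, 73.0, 74.0]
mean, stderr = mean_and_stderr(per_seed_acc)
print(f"{mean:.1f} ± {stderr:.1f}")
```

Comparing two runs this way only makes sense if both use the same number of seeds and the same hyperparameter sweep.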

silvia1993 commented 2 years ago

I reproduced the results following your instructions exactly. In particular, for OfficeHome I used this command:

python -m domainbed.scripts.sweep_clip delete_and_launch \
       --data_dir=./datasets/ \
       --command_launcher local \
       --algorithms SupCLIPBottleneckBase \
       --datasets OfficeHome \
       --n_hparams 10 \
       --n_trials 5 \
       --skip_confirmation \
       --train_script domainbed.scripts.train_clip \
       --single_test_envs \
       --wandb_proj domain_disentanglement \
       --task 'domain_generalization' \
       --hparams '{"clip_model":"'"RN50"'","mlp_depth":2}' \
       --output_dir=./checkpoints/OfficeHome/clip_resnet/SupCLIPBottleneckBase/base \
       --wandb_group OfficeHome_SupCLIPBottleneckBase
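
The nested quoting in the `--hparams` argument above is easy to misread, so here is a small sketch of what the shell actually passes to the script. It uses `shlex` to mimic POSIX shell word splitting and assumes (this is not confirmed in the thread) that the training script decodes the value as JSON.

```python
import json
import shlex

# The --hparams fragment exactly as typed in the command above.
fragment = r"""--hparams '{"clip_model":"'"RN50"'","mlp_depth":2}'"""

# Adjacent quoted pieces are concatenated into a single shell word.
flag, value = shlex.split(fragment)
print(value)

# Assumption: the script parses the value with a JSON decoder.
hparams = json.loads(value)
print(hparams)
```

The three quoted pieces collapse into one JSON object, so the simpler `'{"clip_model":"RN50","mlp_depth":2}'` would pass the same value.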

And this is the environment that I used:

Package                 Version
----------------------- -----------
absl-py                 1.0.0
aiohttp                 3.8.1
aiosignal               1.2.0
async-timeout           4.0.2
asynctest               0.13.0
attrs                   21.4.0
backcall                0.2.0
boto3                   1.23.5
botocore                1.26.5
cachetools              4.2.4
certifi                 2022.5.18.1
charset-normalizer      2.0.12
cleverhans              4.0.0
click                   8.0.4
clip                    1.0
cloudpickle             2.1.0
compressai              1.2.0
cycler                  0.11.0
dataclasses             0.8
decorator               5.1.1
dill                    0.3.4
dm-tree                 0.1.7
docker-pycreds          0.4.0
easydict                1.9
efficientnet-pytorch    0.7.1
einops                  0.4.1
Flask                   2.0.3
Flask-SQLAlchemy        2.5.1
frozenlist              1.2.0
fsspec                  2022.1.0
ftfy                    6.0.3
future                  0.18.2
gast                    0.5.3
gitdb                   4.0.9
GitPython               3.1.18
google-auth             2.6.6
google-auth-oauthlib    0.4.6
grpcio                  1.46.3
gym                     0.24.0
gym-notices             0.0.6
idna                    3.3
idna-ssl                1.1.0
importlib-metadata      4.8.3
importlib-resources     5.4.0
ipython                 7.16.3
ipython-genutils        0.2.0
itsdangerous            2.0.1
jedi                    0.17.2
Jinja2                  3.0.3
jmespath                0.10.0
joblib                  1.1.0
kiwisolver              1.3.1
language-tool-python    2.7.1
Markdown                3.3.7
MarkupSafe              2.0.1
matplotlib              3.3.4
mnist                   0.2.2
mplcursors              0.5.1
multidict               5.2.0
munch                   2.5.0
nose                    1.3.7
numpy                   1.19.5
nvidia-ml-py3           7.352.0
oauthlib                3.2.0
opencv-python           4.5.5.64
packaging               21.3
pandas                  1.1.5
parso                   0.7.1
pathtools               0.1.2
patsy                   0.5.2
pexpect                 4.8.0
pickleshare             0.7.5
Pillow                  8.4.0
pip                     21.3.1
plotly                  5.8.0
pretrainedmodels        0.7.4
promise                 2.3
prompt-toolkit          3.0.29
protobuf                3.19.4
psutil                  5.9.1
psycopg2-binary         2.9.3
ptyprocess              0.7.0
pyasn1                  0.4.8
pyasn1-modules          0.2.8
pycodestyle             2.8.0
pyDeprecate             0.3.1
Pygments                2.12.0
pyparsing               3.0.9
python-dateutil         2.8.2
pytorch-lightning       1.5.10
pytorch-lightning-bolts 0.3.2.post1
pytorch-msssim          0.2.1
pytz                    2022.1
PyYAML                  6.0
regex                   2022.4.24
requests                2.27.1
requests-oauthlib       1.3.1
rsa                     4.8
s3transfer              0.5.2
scikit-learn            0.24.2
scipy                   1.5.4
seaborn                 0.11.2
sentry-sdk              1.5.12
setproctitle            1.2.3
setuptools              59.5.0
shortuuid               1.0.9
six                     1.16.0
sklearn                 0.0
smmap                   5.0.0
SQLAlchemy              1.3.24
SQLAlchemy-Utils        0.38.2
statsmodels             0.12.2
tenacity                8.0.1
tensorboard             2.9.0
tensorboard-data-server 0.6.1
tensorboard-plugin-wit  1.8.1
tensorflow-probability  0.16.0
threadpoolctl           3.1.0
timm                    0.5.4
torch                   1.7.1
torchmetrics            0.8.2
torchvision             0.8.2
tqdm                    4.64.0
traitlets               4.3.3
typing_extensions       4.1.1
urllib3                 1.26.9
wandb                   0.12.16
wcwidth                 0.2.5
Werkzeug                2.0.3
wheel                   0.37.1
yarl                    1.7.2
zipp                    3.6.0

ryoungj commented 2 years ago

The script seems correct to me. I'm not exactly sure what causes the difference, as it could be due to more subtle issues like hardware differences (the results that I obtained are in collect_clip_results.ipynb). Did you also run CLIPPretrained and SupCLIPBottleneckCondCAD to see if the relative differences are similar?