facebookresearch / vissl

VISSL is FAIR's library of extensible, modular and scalable components for SOTA Self-Supervised Learning with images.
https://vissl.ai
MIT License

How to do Supervised Contrastive Learning? #478

Open Pedrexus opened 2 years ago

Pedrexus commented 2 years ago

Hello, everyone.

❓ How to do Supervised Contrastive Learning using VISSL

Is there a way to perform Supervised Contrastive Learning in VISSL? Basically, I'd like to use the class labels I already have to define the "positives", instead of relying only on positives generated automatically from augmented views.

I found the comment below while looking through the code and the docs, so I wonder whether setting LABEL_TYPE: standard would be the solution.

# There are three types of label_type (data labels): "standard",
# "sample_index", and "zero". "standard" uses the labels associated
# with a data set (e.g. directory names). "sample_index" assigns each
# sample a label that corresponds to that sample's index in the
# dataset (first sample will have label == 0, etc.), and is used for
# SSL tasks in which the label is arbitrary. "zero" assigns
# each sample the label == 0, which is necessary when using the
# CutMixUp collator because of the label smoothing that is built in
# to its functionality.

Thanks a lot for this amazing tool!

iseessel commented 2 years ago

Hi @Pedrexus, we don't support Supervised Contrastive Learning, but based on a quick read of the paper, I think all that is required is adding a custom loss function, since this loss is simply a function of the mini-batch outputs and the targets.

You can see how to do that here: https://vissl.readthedocs.io/en/v0.1.6/extend_modules/losses.html.

I also found a PyTorch implementation here: https://github.com/HobbitLong/SupContrast/blob/master/losses.py if you want to reference it for a VISSL implementation.
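
For reference, below is a rough, untested sketch of what that could look like, following the ClassyLoss / register_loss template from the losses page linked above and the SupCon objective from that repo. The registry name supervised_contrastive_loss, the temperature config field, and the assumption that output is an already-projected embedding tensor are placeholders for illustration, not existing VISSL options.

```python
import torch
import torch.nn.functional as F
from classy_vision.losses import ClassyLoss, register_loss


@register_loss("supervised_contrastive_loss")  # placeholder registry name
class SupervisedContrastiveLoss(ClassyLoss):
    """Sketch of the SupCon objective (Khosla et al., 2020) as a custom loss.

    Assumes `output` holds the mini-batch embeddings produced by the model
    and `target` holds the integer class labels (i.e. LABEL_TYPE: standard).
    """

    def __init__(self, loss_config):
        super().__init__()
        # Hypothetical hyperparameter read from the loss config.
        self.temperature = loss_config.get("temperature", 0.1)

    @classmethod
    def from_config(cls, loss_config):
        return cls(loss_config)

    def forward(self, output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # L2-normalize embeddings so dot products are cosine similarities.
        z = F.normalize(output, dim=1)                       # (N, D)
        labels = target.view(-1)                             # (N,)
        n = z.shape[0]

        # Pairwise similarities scaled by the temperature.
        sim = z @ z.T / self.temperature                     # (N, N)
        # Subtract the per-row max for numerical stability.
        sim = sim - sim.max(dim=1, keepdim=True).values.detach()

        self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
        # Positives: samples sharing the anchor's class label, excluding the anchor.
        pos_mask = labels.unsqueeze(0).eq(labels.unsqueeze(1)) & ~self_mask

        # Log-softmax over all non-anchor samples in the batch.
        exp_sim = torch.exp(sim).masked_fill(self_mask, 0.0)
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True))

        # Average the log-probabilities over each anchor's positives and negate;
        # anchors with no positive in the batch are skipped.
        pos_counts = pos_mask.sum(dim=1)
        valid = pos_counts > 0
        mean_log_prob_pos = (log_prob * pos_mask.float()).sum(dim=1)[valid] / pos_counts[valid]
        return -mean_log_prob_pos.mean()
```

If that works, you should be able to point the LOSS section of your training config at the registered name, with temperature (or whatever hyperparameters you choose) nested under it, the same way the built-in losses are configured.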

Pedrexus commented 2 years ago

Thanks a lot @iseessel. I will try implementing it as a custom loss then!