Closed: guarin closed this issue 8 months ago
Thanks for the suggestion! We didn't have much prior experience with SSL, so we chose to match the defaults of the original SimCLR/MoCo papers. Do you know of any papers that demonstrate that no weight decay works better? I'm surprised Google/FAIR didn't find this during their hyperparameter tuning.
I don't think it is mentioned in the SimCLR paper but it is in the code here: https://github.com/google-research/simclr/blob/383d4143fd8cf7879ae10f1046a9baeb753ff438/tf2/model.py#L40-L42
BYOL does the same: https://github.com/google-deepmind/deepmind-research/blob/f5de0ede8430809180254ee957abf36ed62579ef/byol/byol_experiment.py#L191-L195
But I just noticed that you are not using the LARS optimizer, and in SimCLR they only applied this exclusion for LARS. For the other optimizers they didn't use weight decay at all, though I am not sure whether they benchmarked their code with those settings.
Yeah, PyTorch doesn't have a LARS optimizer. Let me do some digging and figure out where I found these weight decay values.
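For reference, the core of LARS is small enough to sketch. This is a simplified, unofficial implementation of the layer-wise trust ratio (no momentum, no bias/norm exclusion list), not the optimizer from the SimCLR codebase:

```python
import torch


class LARS(torch.optim.Optimizer):
    """Minimal layer-wise adaptive rate scaling sketch (You et al., 2017).

    Simplified for illustration: no momentum and no parameter exclusion.
    """

    def __init__(self, params, lr=1.0, weight_decay=0.0, trust_coef=0.001, eps=1e-8):
        defaults = dict(lr=lr, weight_decay=weight_decay, trust_coef=trust_coef, eps=eps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Weight decay is folded into the gradient before scaling.
                g = p.grad + group["weight_decay"] * p
                w_norm = p.norm()
                g_norm = g.norm()
                # Layer-wise trust ratio rescales the global lr per tensor.
                trust = torch.where(
                    (w_norm > 0) & (g_norm > 0),
                    group["trust_coef"] * w_norm / (g_norm + group["eps"]),
                    torch.ones_like(w_norm),
                )
                p.add_(g, alpha=-(group["lr"] * trust.item()))
```

Third-party implementations (e.g. in lightly or timm) add momentum and exclusion of bias/norm parameters on top of this.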
Okay, finally had time to look into this.
> I don't think it is mentioned in the SimCLR paper
Weight decay is mentioned in:
> For the other optimizers they didn't use weight decay at all
You are correct that weight decay is not used in the optimizer, although it is used in the loss function.
Weight decay is mentioned in:
It isn't mentioned in the MoCo v2 paper, although the v2 code is largely the same as v1's. The MoCo v3 paper doesn't give the weight decay value, only that weight decay was used.
In the code base, weight decay is used with SGD in v1/v2, LARS in v3, and AdamW in v3.
If you want to submit a PR that removes weight decay from our SimCLR optimizer and adds it to our loss function, I would be happy to accept it. I'm a little afraid to remove it entirely though.
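For anyone picking this up, here is a sketch of what moving the L2 term into the loss could look like. The helper name and the 1-D-parameter exclusion are my assumptions, loosely following the SimCLR TF2 code linked above; this is not an existing torchgeo function:

```python
import torch


def l2_regularization(model, weight_decay=1e-6):
    """Explicit L2 term to add to the training loss.

    Skips bias and norm parameters (anything 1-D), roughly mirroring
    the exclusion in the SimCLR TF2 reference code. Hypothetical
    helper for illustration only.
    """
    reg = torch.zeros(())
    for param in model.parameters():
        if param.requires_grad and param.ndim > 1:
            reg = reg + param.pow(2).sum()
    return weight_decay * reg


# Usage sketch: build the optimizer with weight_decay=0 and instead do
#   loss = contrastive_loss + l2_regularization(model)
```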
I think this issue can be closed. If users want to reproduce the original MoCo/SimCLR papers, they can use our current defaults. If they want to try to improve performance, they can use weight_decay=0.
Summary
I noticed that your SSL models apply weight decay to all of their parameters; see for example:
https://github.com/microsoft/torchgeo/blob/0f8b0ac3eabf1e4156e7259d34b1b63e6df5d445/torchgeo/trainers/simclr.py#L280-L284
For these models it is usually better not to apply weight decay to batch norm and bias parameters. We implement it like this in lightly:
https://github.com/lightly-ai/lightly/blob/c1e341222b2d199f8a23277a92417debf8caa783/benchmarks/imagenet/resnet50/simclr.py#L65-L84
Could give a small boost to your SSL models :)
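To illustrate the suggestion, here is a minimal sketch of the parameter-group split. The "treat anything 1-D as a norm or bias parameter" heuristic is a common convention and my assumption here, not a quote of the lightly code linked above:

```python
import torch


def split_decay_groups(model, weight_decay=1e-4):
    """Split model parameters into decay / no-decay optimizer groups.

    Bias and normalization parameters (anything 1-D) are excluded from
    weight decay; everything else keeps the given weight_decay.
    """
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D tensors cover biases and the affine params of norm layers.
        if param.ndim <= 1 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


# Usage sketch:
#   optimizer = torch.optim.SGD(split_decay_groups(model), lr=0.1, momentum=0.9)
```

Passing a list of dicts like this to any torch.optim optimizer applies per-group hyperparameters, so only the first group is regularized.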