microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

SSL Weight Decay #1828

Closed guarin closed 8 months ago

guarin commented 9 months ago

Summary

I noticed that your SSL models use weight decay for all their parameters, see for example:

https://github.com/microsoft/torchgeo/blob/0f8b0ac3eabf1e4156e7259d34b1b63e6df5d445/torchgeo/trainers/simclr.py#L280-L284

For these models it is usually better to not use weight decay for batch norm and bias parameters. We implement it like this in lightly:

https://github.com/lightly-ai/lightly/blob/c1e341222b2d199f8a23277a92417debf8caa783/benchmarks/imagenet/resnet50/simclr.py#L65-L84

Could give a small boost to your SSL models :)
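For readers who want to see what this looks like in plain PyTorch, here is a minimal sketch. It is not torchgeo's or lightly's actual code: the helper name `split_weight_decay_params`, the toy model, and the choice of which layer types count as normalization layers are all assumptions made for illustration.

```python
import torch
from torch import nn

# Assumption: these are the layer types treated as "normalization" layers.
NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.LayerNorm, nn.GroupNorm)


def split_weight_decay_params(model: nn.Module):
    """Hypothetical helper: return (decay, no_decay) parameter lists.

    Parameters of normalization layers and every parameter named "bias"
    go into no_decay; everything else goes into decay.
    """
    decay, no_decay = [], []
    for module in model.modules():
        # recurse=False so each parameter is visited once, via its owning module.
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            if isinstance(module, NORM_TYPES) or name == "bias":
                no_decay.append(param)
            else:
                decay.append(param)
    return decay, no_decay


# Toy model standing in for an SSL backbone plus projection head.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)

decay, no_decay = split_weight_decay_params(model)

# Two parameter groups: only the first one is regularized.
optimizer = torch.optim.SGD(
    [
        {"params": decay, "weight_decay": 1e-4},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=0.1,
    momentum=0.9,
)
```

The per-group `weight_decay` overrides the optimizer-level default, so the same pattern should carry over to other `torch.optim` optimizers that accept parameter groups.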

adamjstewart commented 9 months ago

Thanks for the suggestion! We didn't have a lot of prior experience with SSL so we chose to match the defaults of the original SimCLR/MoCo papers. Do you know of any papers that demonstrate that no weight decay works better? I'm surprised Google/FAIR didn't find this during their hyperparameter tuning.

guarin commented 9 months ago

I don't think it is mentioned in the SimCLR paper but it is in the code here: https://github.com/google-research/simclr/blob/383d4143fd8cf7879ae10f1046a9baeb753ff438/tf2/model.py#L40-L42

BYOL does the same: https://github.com/google-deepmind/deepmind-research/blob/f5de0ede8430809180254ee957abf36ed62579ef/byol/byol_experiment.py#L191-L195

But I just noticed that you are not using the LARS optimizer, and in SimCLR they only exclude these parameters when LARS is used. For the other optimizers they didn't use weight decay in the optimizer at all, but I am not sure if they benchmarked their code with these settings.

adamjstewart commented 9 months ago

Yeah, PyTorch doesn't have a LARS optimizer. Let me do some digging and figure out where I found these weight decay values.

adamjstewart commented 9 months ago

Okay, finally had time to look into this.

SimCLR

guarin wrote: "I don't think it is mentioned in the SimCLR paper"

Weight decay is mentioned in:

guarin wrote: "For the other optimizers they didn't use weight decay at all"

You are correct that weight decay is not used in the optimizer, although it is used in the loss function.

MoCo

Weight decay is mentioned in:

It isn't mentioned in MoCo v2, although the code for v2 is largely the same as v1. The value of weight decay for v3 is not mentioned in the paper, just that it was used.

In the code base, weight decay is used with SGD in v1/v2, and with both LARS and AdamW in v3.

adamjstewart commented 9 months ago

If you want to submit a PR that removes weight decay from our SimCLR optimizer and adds it to our loss function, I would be happy to accept it. I'm a little afraid to remove it entirely though.
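As a rough illustration of what "weight decay in the loss function" could mean in PyTorch, here is a minimal sketch. It is not the SimCLR reference implementation and not torchgeo code: the helper `l2_penalty`, the toy model, the placeholder task loss, and the decision to skip 1-D parameters (biases and normalization scales) are assumptions. The factor of 0.5 is chosen so that, for plain SGD, the penalty produces the same gradient as `weight_decay=coeff` in the optimizer; for adaptive optimizers the two are generally not equivalent.

```python
import torch
from torch import nn


def l2_penalty(model: nn.Module, coeff: float) -> torch.Tensor:
    """Hypothetical helper: 0.5 * coeff * sum of squared weights.

    Assumption: 1-D parameters (biases, normalization scales) are skipped.
    The model must have at least one trainable parameter with ndim > 1.
    """
    squares = [
        p.pow(2).sum()
        for p in model.parameters()
        if p.requires_grad and p.ndim > 1
    ]
    return 0.5 * coeff * torch.stack(squares).sum()


torch.manual_seed(0)
model = nn.Linear(4, 2)

# Weight decay is disabled in the optimizer and applied through the loss instead.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=0.0)

x = torch.randn(8, 4)
task_loss = model(x).pow(2).mean()  # placeholder standing in for the contrastive loss
penalty = l2_penalty(model, coeff=1e-4)
loss = task_loss + penalty

optimizer.zero_grad()
loss.backward()
optimizer.step()
```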

adamjstewart commented 8 months ago

I think this issue can be closed. If users want to reproduce the original MoCo/SimCLR papers, they can use our current defaults. If they want to try to improve performance, they can use weight_decay=0.