vturrisi / solo-learn

solo-learn: a library of self-supervised methods for visual representation learning powered by PyTorch Lightning
MIT License

Remove dim=1 parameters from LARS updates and weight decay #301

Closed · turian closed this issue 1 year ago

turian commented 1 year ago

Describe the bug

In addition to solo.utils.misc.remove_bias_and_norm_from_weight_decay, we might also consider a solo.utils.misc.ignore_dim1_parameters helper that excludes all ndim == 1 parameters from both weight decay and the LARS adaptation filter.

Versions

solo-learn main

Additional comments

I see that solo.utils.misc.remove_bias_and_norm_from_weight_decay was added in #289. This is quite nice. (I can't tell from your patch whether bias and norm are also excluded from the LARS adaptation. Is that the case? Your line 371 appears to do this, but it's hard to grok from the comment and the small diff. See the Facebook code here.)

The key thing is that one might wish to exclude ALL ndim == 1 parameters from LARS updates and weight decay. (FB vicreg code) There might be other norms, etc., for which this is a good idea. In the FB code they only had batchnorm, so it didn't matter, but there might be other settings where a general filter would help.
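For concreteness, here's a minimal sketch of the kind of helper meant above (the group layout and the lars_exclude flag are my own illustration, not solo-learn's actual API):

```python
import torch.nn as nn

def ignore_dim1_parameters(model: nn.Module, weight_decay: float):
    """Split parameters into two optimizer groups: ndim == 1 parameters
    (biases, norm scales/shifts, PReLU slopes, etc.) get no weight decay
    and are flagged for exclusion from the LARS adaptation."""
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim == 1 else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        # "lars_exclude" is an illustrative flag a LARS wrapper could
        # read to skip the adaptation for this group
        {"params": no_decay, "weight_decay": 0.0, "lars_exclude": True},
    ]
```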

If this works, I might pester/help pytorch (https://github.com/pytorch/pytorch/issues/1402) and lightning (with whom I've been discussing this issue) to upstream this change.

vturrisi commented 1 year ago

The issue is that, for some reason, I noticed slightly lower performance on a couple of the methods I tried depending on whether remove_bias_and_norm_from_weight_decay was enabled. I'm not sure if the cause is just the combination with the other hyperparameters or something else entirely. I think keeping the two options decoupled is beneficial.

Line 371 just excludes the parameters from the scheduler (only SimSiam does this).
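For context, "excluding from the scheduler" means keeping a fixed learning rate for some parameter groups; the SimSiam paper trains the predictor with a constant LR. A minimal sketch of one way to do this (not solo-learn's actual implementation):

```python
def step_keeping_static_lrs(scheduler, optimizer, static_idxs, base_lr):
    """Advance the LR scheduler, then restore a fixed learning rate for
    the parameter groups in static_idxs (e.g. SimSiam's predictor)."""
    scheduler.step()
    for i in static_idxs:
        optimizer.param_groups[i]["lr"] = base_lr
```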

turian commented 1 year ago

@vturrisi perhaps this could be a configurable option? (Or separate options for disabling weight decay on ndim == 1 parameters and disabling LARS updates on ndim == 1 parameters.)

Here's more evidence:

vturrisi commented 1 year ago

@turian I see. As it is now, we should already have these two options, no? Maybe it's just the naming that's confusing. Here we only apply the LARS adaptation to parameters with ndim != 1 when exclude_bias_n_norm is True. As for excluding from weight decay, I believe I copied some old code from timm, but I think this makes more sense. Nonetheless, exclude_bias_n_norm_wd should trigger this function.
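For reference, the relevant check inside a LARS step looks roughly like this (a simplified sketch of the usual trust-ratio computation, not solo-learn's exact code):

```python
import torch

def lars_scaling(p: torch.Tensor, grad: torch.Tensor,
                 eta: float = 0.02, exclude_bias_n_norm: bool = True):
    """Local LR scaling for one parameter. With exclude_bias_n_norm set,
    ndim == 1 parameters fall through with a scaling of 1.0, i.e. they
    get a plain (non-adapted) update."""
    if exclude_bias_n_norm and p.ndim == 1:
        return 1.0
    w_norm, g_norm = p.norm(), grad.norm()
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0
    return (eta * w_norm / g_norm).item()
```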

Maybe the solution is just:

turian commented 1 year ago

Ah, I see! Yes, I guess I found the names confusing. Maybe a docstring could help too, but the renames would be great.

vturrisi commented 1 year ago

I re-checked my function and found an issue where we were overwriting the weight decay. I'll push a fix for that in the upcoming release. About renaming, I thought a bit more about it, and I think the current name is fair enough, since the only 1-D parameters we have are normalization layers and biases.
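To illustrate the class of bug described (a hypothetical reconstruction; the actual diff isn't shown in this thread):

```python
# Hypothetical reconstruction of the bug class described above,
# not the actual solo-learn code.
def buggy_regroup(param_groups, weight_decay):
    for g in param_groups:
        # unconditionally overwrites groups already set to 0.0
        g["weight_decay"] = weight_decay

def fixed_regroup(param_groups, weight_decay):
    for g in param_groups:
        # only fill in a default; preserve any exclusion applied earlier
        g.setdefault("weight_decay", weight_decay)
```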

turian commented 1 year ago

About renaming, I thought a bit more about it, and I think the current name is fair enough, since the only 1-D parameters we have are normalization layers and biases.

That doesn't seem very forward-thinking. For example, what if some SSL algorithm uses PReLU activations? If I'm not mistaken, those are also ndim == 1, and you probably don't want to apply weight decay or the LARS adaptation to them.
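PReLU's learnable slope is indeed a 1-D tensor, which is easy to check:

```python
import torch.nn as nn

prelu = nn.PReLU(num_parameters=64)  # one learnable slope per channel
print(prelu.weight.ndim)   # 1 -> would be caught by an ndim == 1 filter
print(prelu.weight.shape)  # torch.Size([64])
```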

turian commented 1 year ago

I re-checked my function and found an issue where we were overwriting the weight decay. I'll push a fix for that in the upcoming release.

I'd be happy to do code review of the PR; I'm excited to work on PyTorch Lightning SSL code, given what I'm currently building.

vturrisi commented 1 year ago

Indeed, the name wouldn't be completely clear for that method, even though it would still cover those ndim == 1 learnable parameters. Still, a name like exclude_ndim1 might not be clear to some people.

vturrisi commented 1 year ago

We now have exclude_bias_n_norm_wd as an entry under optimizer and exclude_bias_n_norm as an entry under optimizer/kwargs for when LARS is enabled. This addresses both cases, and I think it's clear enough.
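In other words, the resulting configuration shape is roughly the following (a hand-written illustration with placeholder values; the exact schema lives in solo-learn's config handling):

```python
# Illustrative shape of the two options as a Python dict; weight_decay
# is a placeholder value, not a recommendation.
optimizer_cfg = {
    "name": "lars",
    "weight_decay": 1e-5,
    # skip weight decay for ndim == 1 parameters (biases, norms)
    "exclude_bias_n_norm_wd": True,
    "kwargs": {
        # skip the LARS adaptation for ndim == 1 parameters
        "exclude_bias_n_norm": True,
    },
}
```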