Closed turian closed 1 year ago
The issue is that, for some reason, I noticed slightly lower performance on a couple of methods I tried, both with and without remove_bias_and_norm_from_weight_decay.
I'm not sure whether the cause is just the combination of the other parameters or something else entirely. I think having them decoupled is beneficial.
Line 371 just excludes the parameters from the scheduler (only SimSiam does this).
@vturrisi perhaps this could be a configurable option? (Or separate options for disabling weight decay on ndim 1 parameters and disabling LARS updates on ndim 1 parameters.)
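For reference, the usual way to decouple weight decay from everything else is to build two optimizer parameter groups keyed on `ndim`. A minimal sketch of that idea, not solo-learn's actual implementation; the `FakeParam` class is just a stub standing in for `torch.nn.Parameter` so the example runs without torch:

```python
def split_params_by_ndim(named_params, weight_decay):
    """Split (name, param) pairs into two optimizer param groups:
    ndim==1 params (biases, norm weights) get weight_decay=0,
    everything else keeps the configured weight decay."""
    decay, no_decay = [], []
    for _name, p in named_params:
        if not getattr(p, "requires_grad", True):
            continue  # frozen params are skipped entirely
        (no_decay if p.ndim == 1 else decay).append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class FakeParam:
    """Stand-in for torch.nn.Parameter, so the sketch runs without torch."""
    def __init__(self, ndim):
        self.ndim = ndim
        self.requires_grad = True


groups = split_params_by_ndim(
    [("conv.weight", FakeParam(4)),
     ("bn.weight", FakeParam(1)),
     ("bn.bias", FakeParam(1))],
    weight_decay=1e-4,
)
```

A list like `groups` can be passed directly to any `torch.optim` optimizer, which reads per-group `weight_decay` overrides.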
Here's more evidence:
@turian I see. As it is now, we should already have these two options, no? Maybe just the naming that's confusing.
Here we only update parameters that have ndim != 1 if exclude_bias_n_norm is True. About excluding from weight decay, I believe that I copied some old code from timm, but I think this makes more sense. Nonetheless, exclude_bias_n_norm_wd should trigger this function to be executed.
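To make the LARS side of this concrete: skipping the adaptation for ndim == 1 parameters amounts to falling back to the plain learning rate for them. A rough sketch of the trust-ratio logic, simplified from typical LARS implementations (the exact formula, `eta` default, and edge-case handling vary between codebases):

```python
def lars_trust_ratio(param_norm, grad_norm, weight_decay, ndim,
                     eta=0.02, exclude_bias_n_norm=True, eps=1e-8):
    """Layer-wise LARS trust ratio. Returns 1.0 (i.e. use the plain
    learning rate) when the parameter is excluded (ndim == 1) or when
    either norm is zero."""
    if exclude_bias_n_norm and ndim == 1:
        return 1.0
    if param_norm == 0.0 or grad_norm == 0.0:
        return 1.0
    return eta * param_norm / (grad_norm + weight_decay * param_norm + eps)
```

For example, a 2d weight with ||w|| = 10, ||g|| = 1, and no weight decay gets a trust ratio of roughly 0.2, while a bias with the same norms is left at 1.0.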
Maybe the solution is just:
- renaming exclude_bias_n_norm for LARS to something like exclude_ndim1
- keeping exclude_bias_n_norm_wd
- replacing our implementation with timm's new implementation

Ah I see! Yes, I guess I found the names confusing. Maybe a docstring could help too, but the renames would be great.
I re-checked my function and found an issue that we were overwriting the weight decay. Gonna push a fix for that in the upcoming release. About renaming, I thought a bit more about that and I think it's a fair enough name, since the only 1d parameters that we have are normalization layers and biases.
That doesn't seem very forward-thinking. For example, what if some SSL algorithm uses PReLU activations? If I'm not mistaken, those are also ndim 1, and you probably don't want to weight-decay or LARS-adapt them.
I'm happy to do code review of the PR. I'm excited to work on PTL SSL code, given what I'm currently building.
Indeed, the name wouldn't be completely clear for that method, even though it would still incorporate those ndim=1 learnable parameters. Still, a name like exclude_ndim1 might not be clear for some people.
We now have exclude_bias_n_norm_wd as an entry of optimizer, and exclude_bias_n_norm as an entry of optimizer/kwargs for when LARS is enabled. This addresses both cases and I think it's clear enough.
Describe the bug
In addition to solo.utils.misc.remove_bias_and_norm_from_weight_decay, we might also consider a solo.utils.misc.ignore_dim1_parameters that excludes parameters both from weight decay and from the lars_adaptation_filter.

Versions
solo-learn main
Additional comments
I see that solo.utils.misc.remove_bias_and_norm_from_weight_decay was added in #289. This is quite nice. (I can't tell from your patch, but bias and norm should also be excluded from the LARS adaptation. Is this the case or not? Your line 371 appears to do this, but it's hard to grok from the comment and the small diff. See the Facebook code here.)

The key thing is that one might wish to exclude ALL ndim==1 parameters from LARS updates and weight decay. (FB vicreg code) There might be other norms, etc., for which this is a good idea. In the FB code they only had batchnorm, so it didn't matter, but there might be other settings where this would be a good idea.

If this works, I might pester/help pytorch (https://github.com/pytorch/pytorch/issues/1402) and lightning (with whom I've been discussing this issue) to upstream this change.
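To make the proposal concrete, the hypothetical ignore_dim1_parameters (the name is just the one suggested in this issue, not an existing function) could return parameter groups that both zero the weight decay and carry a flag a LARS optimizer can read to skip the adaptation. A sketch under those assumptions; `FakeParam` is a stub for `torch.nn.Parameter` so the example runs without torch, and the `lars_exclude` group key is a convention some LARS implementations use, not a torch built-in:

```python
def ignore_dim1_parameters(named_params, weight_decay):
    """Hypothetical helper: every ndim==1 parameter (bias, norm weight,
    PReLU weight, ...) goes into a group with weight_decay=0 and
    lars_exclude=True, so a LARS optimizer can skip both weight decay
    and the trust-ratio adaptation for it."""
    regular, dim1 = [], []
    for _name, p in named_params:
        (dim1 if p.ndim == 1 else regular).append(p)
    return [
        {"params": regular, "weight_decay": weight_decay, "lars_exclude": False},
        {"params": dim1, "weight_decay": 0.0, "lars_exclude": True},
    ]


class FakeParam:
    """Stand-in for torch.nn.Parameter so the sketch runs without torch."""
    def __init__(self, ndim):
        self.ndim = ndim


groups = ignore_dim1_parameters(
    [("conv.weight", FakeParam(4)), ("prelu.weight", FakeParam(1))],
    weight_decay=1e-5,
)
```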