kozistr / pytorch_optimizer

optimizer & lr scheduler & loss function collections in PyTorch
https://pytorch-optimizers.readthedocs.io/en/latest/
Apache License 2.0

[Fix] Implement better `wd_ban_list` handling #282

Closed Vectorrent closed 1 month ago

Vectorrent commented 1 month ago

Problem (Why?)

The `wd_ban_list` argument for `get_optimizer_parameters()` is somewhat misleading. Looking at its default value, you would expect any of the name formats it contains to work correctly. However, that is not the case.

wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')

From this list, only `bias` is detected and banned correctly. Neither `LayerNorm.bias` nor `LayerNorm.weight` is matched, so neither of those parameters has its `weight_decay` set to 0.

I even tested a bare `LayerNorm` entry, and that doesn't work either.
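To make the mismatch concrete, here is a minimal sketch (the two-layer model is hypothetical). The fully-qualified names from `named_parameters()` use attribute paths, never the module's class name, so only `bias` can ever match a substring check:

```python
import torch.nn as nn

# A hypothetical minimal model containing a LayerNorm.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))

# Fully-qualified parameter names use attribute paths like '1.weight' --
# the class name 'LayerNorm' never appears in them.
names = [name for name, _ in model.named_parameters()]
print(names)  # ['0.weight', '0.bias', '1.weight', '1.bias']

# A substring check against the class-name entries finds nothing,
# even though the model clearly contains a LayerNorm:
hits = [n for n in names if 'LayerNorm.bias' in n or 'LayerNorm.weight' in n]
print(hits)  # []
```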

Solution (What/How?)

The reason this fails is that the `wd_ban_list` logic only checks the actual, fully-qualified parameter names; it does NOT check the class name of each `nn.Module`, as pytorch_optimizer's default arguments and tests would imply.

I implemented a more complete method for handling the `wd_ban_list`. Now we check both the fully-qualified parameter names and the `nn.Module` class names.
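A sketch of that approach (function names here are hypothetical, not the library's actual implementation): each ban entry is treated either as a fragment of the parameter's fully-qualified name, or as a module class name with an optional parameter suffix.

```python
import torch.nn as nn

def collect_banned(model, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    """Return the set of fully-qualified parameter names to exclude from weight decay.

    A ban entry may be a fragment of the parameter name ('bias') or a module
    class name, optionally with a parameter suffix ('LayerNorm.weight').
    """
    banned = set()
    for module_name, module in model.named_modules():
        for param_name, _ in module.named_parameters(recurse=False):
            full_name = f'{module_name}.{param_name}' if module_name else param_name
            for entry in wd_ban_list:
                cls, _, suffix = entry.partition('.')
                # Match on the raw fully-qualified name, OR on the module's
                # class name (plus an optional parameter suffix).
                if entry in full_name or (
                    module.__class__.__name__ == cls and (not suffix or suffix == param_name)
                ):
                    banned.add(full_name)
    return banned

def split_params(model, weight_decay, wd_ban_list=('bias', 'LayerNorm.bias', 'LayerNorm.weight')):
    # Split parameters into a decayed group and a non-decayed group.
    banned = collect_banned(model, wd_ban_list)
    return [
        {'params': [p for n, p in model.named_parameters() if n not in banned],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters() if n in banned],
         'weight_decay': 0.0},
    ]
```

With this matching rule, a bare class-name entry like `LayerNorm` also works, since the empty suffix matches every parameter of that module.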

Notes

I've been using this patch in my own code for several weeks now; it seems to work great! Let me know if there is anything you would change.

Vectorrent commented 1 month ago

I just pushed a new commit, with a few fixes. However, there is one error I was not able to fix:

pytorch_optimizer/optimizer/utils.py:201:5: D212 [*] Multi-line docstring summary should start at the first line
    |
199 |       wd_ban_list: List[str] = ('bias', 'LayerNorm.bias', 'LayerNorm.weight'),
200 |   ) -> PARAMETERS:
201 |       r"""
    |  _____^
202 | |     Get optimizer parameters while filtering specified modules.
203 | |     :param model_or_parameter: Union[nn.Module, List]. model or parameters.
204 | |     :param weight_decay: float. weight_decay.
205 | |     :param wd_ban_list: List[str]. ban list not to set weight decay.
206 | |     :returns: PARAMETERS. new parameter list.
207 | |     """
    | |_______^ D212
208 |   
209 |       fully_qualified_names = []
    |
    = help: Remove whitespace after opening quotes

Found 3 errors.
[*] 2 fixable with the `--fix` option.
make: *** [Makefile:16: check] Error 1

Running `make format` fixes this issue, but then `make check` fails. If I instead fix it manually, `make check` passes, but now `make format` fails!

I'm not super familiar with make, so I don't really know what to do here.

kozistr commented 1 month ago

> If you run make format, it fixes this issue. But then, if you run make check, it fails. [...] I'm not super familiar with make, so I don't really know what to do here.

It's okay, I can handle the lint stuff.

anyway, thanks for the contributions!