airctic / icevision

An Agnostic Computer Vision Framework - Pluggable to any Training Library: Fastai, Pytorch-Lightning with more to come
https://airctic.github.io/icevision/
Apache License 2.0

PyTorch Lightning examples ignore model.param_groups(), so the parameter groups are never passed to the optimizer #896

Open potipot opened 3 years ago

potipot commented 3 years ago

🐛 Bug

The tutorials and examples on the website always initialize PyTorch Lightning optimizers like this:

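from torch.optim import SGD

class LightModel(model_type.lightning.ModelAdapter):  # model_type: the IceVision model module chosen in the tutorial
    def configure_optimizers(self):
        # self.parameters() is one flat iterable, so SGD ends up with a single param group
        return SGD(self.parameters(), lr=1e-4)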

The problem is that this calls self.parameters() instead of self.model.param_groups(). As a result, the optimizer is initialized with a single param group and loses what the parameter split is for, such as training groups with different learning rates or freezing some of them.
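To make the single-group effect concrete, here is a minimal sketch in plain PyTorch. It assumes a toy nn.Sequential stands in for a detector's backbone, neck and head (it is not an IceVision model). A flat iterable yields one param group; one dict per group yields separate groups that can each carry their own hyperparameters.

```python
import torch.nn as nn
from torch.optim import SGD

# Toy stand-in for a detector with backbone/neck/head (hypothetical, not an IceVision model)
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 8), nn.Linear(8, 2))
backbone, neck, head = model[0], model[1], model[2]

# What the tutorials do: one flat iterable -> one param group, one lr for everything
flat_opt = SGD(model.parameters(), lr=1e-4)
print(len(flat_opt.param_groups))  # 1

# One dict per group: groups without an "lr" key fall back to the optimizer-wide lr
grouped_opt = SGD(
    [
        {"params": backbone.parameters(), "lr": 1e-5},
        {"params": neck.parameters()},
        {"params": head.parameters()},
    ],
    lr=1e-4,
)
print(len(grouped_opt.param_groups))  # 3
print([g["lr"] for g in grouped_opt.param_groups])  # [1e-05, 0.0001, 0.0001]
```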

This shouldn't be complicated to fix; however, simply swapping one call for the other doesn't work. The reason is that fastai has its own Optimizer init logic:

class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self, params, cbs, train_bn=True, **defaults):
        params = L(params)
        self.cbs,self.state,self.train_bn = L(cbs),defaultdict(dict),train_bn
        defaults = merge(*self.cbs.attrgot('defaults'), defaults)
        self.param_lists = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])
        self.hypers = L({} for _ in range_of(self.param_lists))
        self.set_hypers(**defaults)
        self.frozen_idx = 0

Here param groups given as a List[List[Parameter]] are accepted, but regular torch optimizers cannot convert that structure and raise an error:

raise TypeError("optimizer can only optimize Tensors, "
TypeError: optimizer can only optimize Tensors, but one of the params is list
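The error can be reproduced without IceVision or fastai. In this sketch a hypothetical `nested` list stands in for what `model.param_groups()` returns; wrapping each inner list in a `{"params": ...}` dict is one way to get a format torch optimizers accept.

```python
import torch.nn as nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
# Hypothetical stand-in for model.param_groups(): a list of lists of Parameters
nested = [list(model[0].parameters()), list(model[1].parameters())]

error_message = None
try:
    SGD(nested, lr=1e-4)  # torch treats each inner list as a "param" and rejects it
except TypeError as err:
    error_message = str(err)
print(error_message)

# Wrapping each inner list in a dict gives torch the per-group format it expects
opt = SGD([{"params": group} for group in nested], lr=1e-4)
print(len(opt.param_groups))  # 2
```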

Solution

A potential solution would be to create a working example and update the tutorials accordingly. In the fastai case, the param_groups call is wrapped in the adapted_fastai_learner function and is not directly exposed to the user.
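A working example could look roughly like the sketch below. `ToyDetector` and `param_groups_to_dicts` are hypothetical names used for illustration, not part of IceVision; the toy model only mimics the List[List[Parameter]] shape of `param_groups()`. In a LightningModule, the optimizer construction would go in `configure_optimizers`, using `self.model.param_groups()` instead.

```python
from typing import Dict, List, Optional, Sequence

import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Adam


class ToyDetector(nn.Module):
    """Hypothetical stand-in whose param_groups() returns a list of lists of Parameters."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.neck = nn.Linear(8, 8)
        self.head = nn.Linear(8, 2)

    def param_groups(self) -> List[List[Parameter]]:
        return [list(m.parameters()) for m in (self.backbone, self.neck, self.head)]


def param_groups_to_dicts(
    groups: Sequence[Sequence[Parameter]], lrs: Optional[Sequence[float]] = None
) -> List[Dict]:
    """Hypothetical helper: convert a list of lists into the list of dicts torch optimizers accept."""
    if lrs is not None and len(lrs) != len(groups):
        raise ValueError("need one learning rate per parameter group")
    dicts = []
    for i, group in enumerate(groups):
        d = {"params": list(group)}
        if lrs is not None:
            d["lr"] = lrs[i]  # per-group override of the optimizer-wide lr
        dicts.append(d)
    return dicts


model = ToyDetector()
# Inside a LightningModule this would be the body of configure_optimizers()
optimizer = Adam(param_groups_to_dicts(model.param_groups(), lrs=[1e-5, 1e-4, 1e-3]), lr=1e-4)
print([g["lr"] for g in optimizer.param_groups])  # [1e-05, 0.0001, 0.001]
```

Keeping the conversion in one helper means every example could share it instead of hard-coding the number of groups, as the yolov5-specific snippet further down does.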

potipot commented 3 years ago

model.param_groups() returns a list of lists, which cannot be passed directly to a regular torch Optimizer: its init expects an iterable of tensors or an iterable of dicts, and fails when converting an L of Ls. I've come up with something like this:

    def configure_optimizers(self):
        # let's copy the YOLOv5 optimizer param group specification exactly from:
        # https://github.com/ultralytics/yolov5/blob/2c073cd207bae1163b472c561d3fd31b1d2ba870/train.py#L148

        # for yolov5 there are 3 parameter groups (backbone, neck and head), implemented in:
        # https://github.com/airctic/icevision/blob/944b47c5694243ba3f3c8c11a6ef56f05fb111eb/icevision/models/ultralytics/yolov5/model.py#L87
        backbone, neck, head = self.model.param_groups()
        param_dict = [
            {"params": backbone},
            {"params": neck, "weight_decay": self.hparams.weight_decay},
            {"params": head},
        ]
        # here the learning rate will be specified uniformly for all parameter groups
        optimizer = Adam(
            param_dict,
            lr=self.hparams.learning_rate,
            betas=(self.hparams.momentum, 0.999),
        )
        return optimizer

FraPochetti commented 2 years ago

@potipot thoughts on this one?

potipot commented 2 years ago

This needs to be fixed in all examples that use PyTorch Lightning; keeping this one on the TODO list.

FraPochetti commented 2 years ago

It's true that PL support doesn't really shine in IceVision. We would also need to support some form of freezing (here); otherwise param_groups would remain largely irrelevant, since all groups would still be trained (or frozen) together, all with the same LR.
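For the freezing side, here is a minimal sketch in plain PyTorch. It assumes a hypothetical `freeze_to` helper loosely modeled on fastai's `freeze_to` (negative indices count from the end; fastai's BatchNorm handling via `train_bn` is left out). Frozen parameters receive no gradient, and SGD skips parameters whose `.grad` is `None`, so only the unfrozen groups are updated.

```python
from typing import List, Sequence

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import SGD


def freeze_to(groups: Sequence[Sequence[Parameter]], n: int) -> None:
    """Hypothetical helper: freeze every group before index n and unfreeze the rest.

    Loosely modeled on fastai's freeze_to (negative n counts from the end);
    fastai's special handling of BatchNorm layers is not reproduced here.
    """
    frozen_idx = n if n >= 0 else len(groups) + n
    for i, group in enumerate(groups):
        for p in group:
            p.requires_grad_(i >= frozen_idx)


backbone, neck, head = nn.Linear(4, 8), nn.Linear(8, 8), nn.Linear(8, 2)
groups: List[List[Parameter]] = [list(m.parameters()) for m in (backbone, neck, head)]

freeze_to(groups, -1)  # train only the head, like a fastai "frozen" phase
optimizer = SGD([{"params": g} for g in groups], lr=1e-2)

x = torch.randn(3, 4)
loss = head(neck(backbone(x))).sum()
loss.backward()
optimizer.step()  # parameters with grad None (the frozen groups) are skipped

grad_is_none = [p.grad is None for g in groups for p in g]
print(grad_is_none)  # [True, True, True, True, False, False]
```

Calling `freeze_to(groups, 0)` afterwards would unfreeze everything, which is roughly the two-phase schedule fastai users get from `fine_tune`.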

fastai support is far superior at the moment.