KeyError in save_hyperparameters when used in a subclass #18405

Open vpozdnyakov opened 1 year ago

vpozdnyakov commented 1 year ago

Bug description

A KeyError occurs when save_hyperparameters() is called in a LightningModule that is instantiated from a method other than __init__ in a subclass (Model.fit below, which also calls super().fit()).

What version are you seeing the problem on?


How to reproduce the bug

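from pytorch_lightning import LightningModule

class Base:
    def fit(self):
        pass  # placeholder body

class LightningModel(LightningModule):
    def __init__(self, hidden_dim):
        super().__init__()
        self.save_hyperparameters()  # raises KeyError: 'args' when reached via Model.fit

class Model(Base):
    def fit(self):
        super().fit()
        self.model = LightningModel(hidden_dim=2)

model = Model()
model.fit()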

Error messages and logs

KeyError                                  Traceback (most recent call last)
<ipython-input-3-d10b384b9dc7> in <cell line: 18>()
     17 model = Model()
---> 18

<ipython-input-3-d10b384b9dc7> in fit(self)
     13     def fit(self):
     14         super().fit()
---> 15         self.model = LightningModel(hidden_dim=2)
     17 model = Model()

<ipython-input-3-d10b384b9dc7> in __init__(self, hidden_dim)
      8     def __init__(self, hidden_dim):
      9         super().__init__()
---> 10         self.save_hyperparameters()
     12 class Model(Base):

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/core/mixins/ in save_hyperparameters(self, ignore, frame, logger, *args)
    109             if current_frame:
    110                 frame = current_frame.f_back
--> 111         save_hyperparameters(self, *args, ignore=ignore, frame=frame)
    113     def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/ in save_hyperparameters(obj, ignore, frame, *args)
    162         from pytorch_lightning.core.mixins import HyperparametersMixin
--> 164         for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
    165             init_args.update(local_args)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/ in collect_init_args(frame, path_args, inside, classes)
    132         # recursive update
    133         path_args.append(local_args)
--> 134         return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
    135     if not inside:
    136         return collect_init_args(frame.f_back, path_args, inside=False, classes=classes)

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/ in collect_init_args(frame, path_args, inside, classes)
    128         return path_args
--> 130     local_self, local_args = _get_init_args(frame)
    131     if "__class__" in local_vars and (not classes or isinstance(local_self, classes)):
    132         # recursive update

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/ in _get_init_args(frame)
     94     exclude_argnames = (*filtered_vars, "__class__", "frame", "frame_args")
     95     # only collect variables that appear in the signature
---> 96     local_args = {k: local_vars[k] for k in init_parameters}
     97     # kwargs_var might be None => raised an error by mypy
     98     if kwargs_var:

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/utilities/ in <dictcomp>(.0)
     94     exclude_argnames = (*filtered_vars, "__class__", "frame", "frame_args")
     95     # only collect variables that appear in the signature
---> 96     local_args = {k: local_vars[k] for k in init_parameters}
     97     # kwargs_var might be None => raised an error by mypy
     98     if kwargs_var:

KeyError: 'args'


More info

It works if I rename the subclass method (fit) to __init__.

amitkparekh commented 11 months ago

I'm getting the same error, but in my case it is because I'm using Typer/Click as the entry point to my application.

I have no idea how to easily solve this apart from not saving hyperparameters at all or not using Typer/Click.

I was able to patch it by changing this line in _get_init_args to:

local_args = {k: local_vars[k] for k in init_parameters if k in local_vars}

I'm not sure whether this approach has wider ramifications, and I'm not entirely sure what this code is doing, but I'd be happy to submit a PR if someone can offer some guidance.
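A minimal sketch of what this one-line patch changes, using a hand-written dict in place of a real frame. It assumes the frame being inspected is Model.fit from the original report: its locals hold self and __class__, and because Model defines no __init__ of its own, the signature being read is that of object.__init__, i.e. (self, /, *args, **kwargs). The dict and the helper names are illustrative, not Lightning's code.

```python
import inspect


class Model:
    """Defines no __init__, so Model.__init__ is object.__init__."""


# Hand-written stand-in for the locals of a Model.fit() frame that called super().
local_vars = {"self": Model(), "__class__": Model}

# object.__init__ reports the signature (self, /, *args, **kwargs).
init_parameters = inspect.signature(local_vars["__class__"].__init__).parameters


def original_lookup():
    # Same shape as the failing line: every signature name must be a local.
    return {k: local_vars[k] for k in init_parameters}


def patched_lookup():
    # The proposed guard: skip names that are not in the frame's locals.
    return {k: local_vars[k] for k in init_parameters if k in local_vars}


try:
    original_lookup()
except KeyError as err:
    print("original raises KeyError:", err.args[0])  # original raises KeyError: args

print("patched keeps:", list(patched_lookup()))  # patched keeps: ['self']
```

In this simplified model the trade-off is visible: the guard avoids the crash, but any signature name missing from the frame is dropped silently instead of being reported.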

d-a-bunin commented 11 months ago

I had a similar issue and found a solution. In your case, you could try creating a function like:

def create_model():
    return LightningModel(hidden_dim=2)

and call it from the fit method instead of instantiating LightningModel there directly.
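A rough sketch of why the helper function helps, under a simplified model of the frame walk that keeps only the rule visible in the traceback: keep following f_back while a frame's locals contain __class__. frame_chain, Module, Direct and ViaHelper are hypothetical names, and the loop is not Lightning's collect_init_args.

```python
import inspect


def frame_chain(frame):
    """Names of consecutive frames, starting at `frame`, whose locals hold __class__."""
    names = []
    while frame is not None and "__class__" in frame.f_locals:
        names.append(frame.f_code.co_name)
        frame = frame.f_back
    return names


class Module:
    """Hypothetical stand-in for LightningModel."""

    def __init__(self, hidden_dim):
        super().__init__()  # zero-argument super() gives this frame a __class__ cell
        self.hidden_dim = hidden_dim
        self.chain = frame_chain(inspect.currentframe())


class Base:
    def fit(self):
        pass


class Direct(Base):
    def fit(self):
        super().fit()
        return Module(hidden_dim=2).chain  # the walk continues into this fit() frame


def create_model():
    return Module(hidden_dim=2)  # plain function: no __class__ in its locals


class ViaHelper(Base):
    def fit(self):
        super().fit()
        return create_model().chain  # the walk stops at create_model()


print(Direct().fit())     # ['__init__', 'fit']
print(ViaHelper().fit())  # ['__init__']
```

Because create_model is a plain function, its frame has no __class__, so in this model the walk never reaches the fit frame whose locals lack the expected __init__ arguments.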

jamesdeeel commented 5 months ago

I ran into this issue too. I think the crux of it is that if you call super() anywhere in a method that isn't __init__, then __class__ is added to that method's local variables. When this happens, the recursive argument parser assumes that Model is a subclass of LightningModel, so it tries to pull the initialisation arguments of Model.__init__(...) out of the local variables and add them to the dict of variables to be saved. The problem is that we aren't in Model.__init__(...), so those initialisation arguments are not present in the local variables.
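The first part of this explanation can be checked in plain Python. A method that uses zero-argument super() gets an implicit __class__ cell, and that name shows up in the locals returned by inspect.getargvalues, while a method or function that never mentions super() has no such entry. The class names below are illustrative.

```python
import inspect


class Base:
    def fit(self):
        pass


class UsesSuper(Base):
    def fit(self):
        super().fit()  # using super() makes the compiler add an implicit __class__ cell
        _, _, _, local_vars = inspect.getargvalues(inspect.currentframe())
        return "__class__" in local_vars


class NoSuper(Base):
    def fit(self):
        _, _, _, local_vars = inspect.getargvalues(inspect.currentframe())
        return "__class__" in local_vars


def plain_function():
    _, _, _, local_vars = inspect.getargvalues(inspect.currentframe())
    return "__class__" in local_vars


print(UsesSuper().fit())  # True
print(NoSuper().fit())    # False
print(plain_function())   # False
```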

I think amitkparekh's patch above,

local_args = {k: local_vars[k] for k in init_parameters if k in local_vars}

is the best approach, but alternatively there could be a check to make sure that we are in an __init__ method before collecting the local args:

if "__class__" in local_vars and (not classes or isinstance(local_self, classes)) and frame.f_back.f_code.co_name == "__init__":
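One caveat, based on the traceback: _get_init_args(frame) runs on line 130, before the condition on line 131, and when that condition fails the walk still recurses to frame.f_back (lines 135-136). Every frame the walk reaches, including Model.fit, therefore still goes through the lookup, so an extra clause in that if statement may not prevent the KeyError on its own. The sketch below is a simplified, hypothetical collector (collect_sketch, Module and GuardedModule are made-up names, not Lightning's API) that applies the name check before reading any locals. It checks the name of the frame whose locals it reads (frame.f_code.co_name) rather than frame.f_back; both choices are assumptions.

```python
import inspect


def collect_sketch(frame, require_init):
    """Hypothetical simplified collector, not Lightning's collect_init_args.

    Walks f_back while frames hold __class__ and reads each class's __init__
    parameters from that frame's locals. With require_init=True it stops
    before the lookup at any frame whose code is not named __init__.
    """
    collected = []
    while frame is not None and "__class__" in frame.f_locals:
        if require_init and frame.f_code.co_name != "__init__":
            break  # guard runs before any locals are looked up
        local_vars = frame.f_locals
        params = inspect.signature(local_vars["__class__"].__init__).parameters
        collected.append({k: local_vars[k] for k in params if k != "self"})
        frame = frame.f_back
    return collected


class Module:
    """Hypothetical stand-in for LightningModel; require_init selects the guard."""

    require_init = False

    def __init__(self, hidden_dim):
        super().__init__()
        self.hparams = collect_sketch(inspect.currentframe(), self.require_init)


class GuardedModule(Module):
    require_init = True


class Base:
    def fit(self):
        pass


class Model(Base):
    def fit(self, module_cls):
        super().fit()
        return module_cls(hidden_dim=2).hparams


try:
    Model().fit(Module)
except KeyError as err:
    print("unguarded:", repr(err))  # unguarded: KeyError('args')

print("guarded:", Model().fit(GuardedModule))  # guarded: [{'hidden_dim': 2}]
```

In a genuine inheritance chain, where each subclass __init__ calls super().__init__(), every frame is named __init__, so in this sketch the name check would not cut that walk short.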