mit-ll-responsible-ai / hydra-zen

Create powerful Hydra applications without the yaml files and boilerplate code.
https://mit-ll-responsible-ai.github.io/hydra-zen/
MIT License

Support for pytorch lightning 'save_hyperparameters' #216

Closed: rallen10 closed this issue 2 years ago

rallen10 commented 2 years ago

When I add save_hyperparameters() to my LightningModule's __init__, I get the following error when running my hydra-zen app:

Traceback (most recent call last):
  File "src/pyrmm/modelgen/dubins.py", line 315, in task_function
    trainer.fit(obj.pl_module, obj.data_module)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 740, in fit
    self._call_and_handle_interrupt(
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 685, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1188, in _run
    self._pre_dispatch()
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1224, in _pre_dispatch
    self._log_hyperparams()
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1261, in _log_hyperparams
    self.logger.save()
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py", line 50, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/loggers/tensorboard.py", line 264, in save
    save_hparams_to_yaml(hparams_file, self.hparams)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 389, in save_hparams_to_yaml
    yaml.dump(v)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/__init__.py", line 253, in dump
    return dump_all([data], stream, Dumper=Dumper, **kwds)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/__init__.py", line 241, in dump_all
    dumper.represent(data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 27, in represent
    node = self.represent_data(data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 356, in represent_object
    return self.represent_mapping(tag+function_name, value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 286, in represent_tuple
    return self.represent_sequence('tag:yaml.org,2002:python/tuple', data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 207, in represent_dict
    return self.represent_mapping('tag:yaml.org,2002:map', data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 48, in represent_data
    node = self.yaml_representers[data_types[0]](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 199, in represent_list
    return self.represent_sequence('tag:yaml.org,2002:seq', data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 92, in represent_sequence
    node_item = self.represent_data(item)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 342, in represent_object
    return self.represent_mapping(
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 118, in represent_mapping
    node_value = self.represent_data(item_value)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 52, in represent_data
    node = self.yaml_multi_representers[data_type](self, data)
  File "/home/ross/miniconda3/envs/pyrmm/lib/python3.8/site-packages/yaml/representer.py", line 331, in represent_object
    if function.__name__ == '__newobj__':
AttributeError: 'str' object has no attribute '__name__'

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace

If I remove the save_hyperparameters() call, everything works fine.

jgbos commented 2 years ago

Hi @rallen10, I'm guessing this relates to variables not supported by PL's save_hyperparameters. Can you provide the config for both pl_module and data_module?

rsokl commented 2 years ago

This doesn't look like a hydra-zen or Hydra issue. I expect that if you initialize your PL module by hand, you will get the exact same error.

If you followed our PL How-To guide, then what is likely going on is that you passed your optimizer into the __init__ of your LightningModule. save_hyperparameters() will attempt to save everything that was passed to your __init__, but it can only handle primitive data types (like int and bool).
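To find out which __init__ arguments are the culprits, something like the following stdlib-only sketch could help. It is not Lightning code: the set of "primitive" types, the helper names is_primitive and non_primitive_init_args, and the MyModule class are all hypothetical and chosen for illustration. It binds the arguments you would pass to __init__ (defaults included) and lists the ones whose values are not primitive.

```python
import inspect

# Assumption for this sketch: "primitive" means these built-in scalar types.
# Lightning's actual serialization rules may be broader or narrower.
PRIMITIVE_TYPES = (int, float, bool, str, type(None))


def is_primitive(value):
    """True for primitive scalars and for lists/tuples/dicts built only from them."""
    if isinstance(value, PRIMITIVE_TYPES):
        return True
    if isinstance(value, (list, tuple)):
        return all(is_primitive(v) for v in value)
    if isinstance(value, dict):
        return all(isinstance(k, str) and is_primitive(v) for k, v in value.items())
    return False


def non_primitive_init_args(cls, *args, **kwargs):
    """Bind args to cls.__init__ (defaults included) and return the names of
    arguments whose values are not primitive."""
    bound = inspect.signature(cls.__init__).bind(None, *args, **kwargs)  # None fills `self`
    bound.apply_defaults()
    return [
        name
        for name, value in bound.arguments.items()
        if name != "self" and not is_primitive(value)
    ]


# Hypothetical stand-in for a LightningModule whose __init__ receives objects.
class MyModule:
    def __init__(self, model, optimizer, layer_1_dim=128, learning_rate=1e-3):
        self.model = model
        self.optimizer = optimizer


print(non_primitive_init_args(MyModule, object(), object()))
# prints ['model', 'optimizer']
```

Here object() stands in for the torch.nn.Module and optimizer instances; with real objects the same two names would be flagged.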

I believe you can work around this by passing the names of only the parameters you want to save, e.g. self.save_hyperparameters("layer_1_dim", "learning_rate"). If you leave out the names of non-primitive arguments, this error should go away.
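The idea behind naming only the parameters you want can be sketched without Lightning installed. In this minimal example the helper select_hparams and the MyModule class are hypothetical, and json is used as a stdlib stand-in for the YAML dump that Lightning's logger performs. Only the named, primitive arguments are recorded, so serialization succeeds even though model and optimizer are arbitrary objects. (Newer Lightning releases also document an ignore argument, e.g. self.save_hyperparameters(ignore=["model", "optimizer"]); check the docs for the version you have installed.)

```python
import json


def select_hparams(init_locals, *names):
    """Hypothetical helper: keep only the named entries of an __init__'s locals()."""
    missing = [n for n in names if n not in init_locals]
    if missing:
        raise KeyError(f"not __init__ arguments: {missing}")
    return {n: init_locals[n] for n in names}


class MyModule:  # hypothetical stand-in for a LightningModule
    def __init__(self, model, optimizer, layer_1_dim=128, learning_rate=1e-3):
        # Called first, so locals() holds only the __init__ arguments (and `self`).
        # Mirrors the idea of self.save_hyperparameters("layer_1_dim", "learning_rate").
        self.hparams = select_hparams(locals(), "layer_1_dim", "learning_rate")
        self.model = model
        self.optimizer = optimizer


module = MyModule(model=object(), optimizer=object())
# json stands in for the YAML dump that Lightning's logger performs.
print(json.dumps(module.hparams))
# prints {"layer_1_dim": 128, "learning_rate": 0.001}
```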

However, given that you are using hydra-zen, there isn't really a need to use save_hyperparameters() anymore 😄 (Edit: @jgbos pointed out that save_hyperparameters() can still be useful for logging parameters to TensorBoard). You can track all of your hyperparameters from your Hydra configs and load your LightningModule from the YAML file that gets serialized whenever you launch your job.
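Why the Hydra config sidesteps the problem: a config records how to build an object (an import path under Hydra's _target_ key plus plain keyword arguments), not the object itself, so it always serializes cleanly. hydra-zen's builds and instantiate do this far more thoroughly; the stdlib-only toy below mimics only the _target_ idea, and the helper names toy_builds and toy_instantiate are hypothetical, not hydra-zen's API.

```python
import importlib
import json
from datetime import timedelta


def toy_builds(target, **kwargs):
    """Hypothetical helper: describe a call as a plain dict using Hydra's `_target_` key."""
    return {"_target_": f"{target.__module__}.{target.__qualname__}", **kwargs}


def toy_instantiate(cfg):
    """Hypothetical helper: import `_target_` and call it with the remaining keys.
    Only handles top-level names (no nested classes)."""
    cfg = dict(cfg)
    module_name, _, attr = cfg.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_name), attr)
    return target(**cfg)


cfg = toy_builds(timedelta, days=1, hours=2)
text = json.dumps(cfg)  # the config is plain data, so it serializes cleanly
print(text)
# prints {"_target_": "datetime.timedelta", "days": 1, "hours": 2}

obj = toy_instantiate(json.loads(text))  # rebuild the object from the saved config
print(obj == timedelta(hours=26))
# prints True
```

In a real project, the target would be your model, optimizer, or LightningModule, and the YAML file that Hydra writes for each run plays the role of text here.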

rallen10 commented 2 years ago

Yeah, after I posted this, I realized that my __init__() arguments include a torch.nn.Module object and the optimizer, neither of which is a primitive type.

This identifies my problem, so I will close this issue.