Closed WhaleCoded closed 1 year ago
Recreated on my machine with the following interactive session:
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from autoclip.torch import QuantileClip
>>> model = torch.nn.Linear(5, 5)
>>> optimizer = torch.optim.Adam(
... model.parameters(),
... lr = 0.05,
... weight_decay = 0.1,
... )
>>> optimizer = QuantileClip.as_optimizer(optimizer, 0.5)
>>> torch.save(optimizer, 'test_file.pth')
>>> optimizer = torch.load('test_file.pth')
>>> optimizer.lr
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/tanner/miniconda3/envs/branch/lib/python3.9/site-packages/autoclip/torch/clipper.py", line 235, in __getattr__
return getattr(self.optimizer, attr)
File "/home/tanner/miniconda3/envs/branch/lib/python3.9/site-packages/autoclip/torch/clipper.py", line 235, in __getattr__
return getattr(self.optimizer, attr)
File "/home/tanner/miniconda3/envs/branch/lib/python3.9/site-packages/autoclip/torch/clipper.py", line 235, in __getattr__
return getattr(self.optimizer, attr)
[Previous line repeated 996 more times]
RecursionError: maximum recursion depth exceeded
Looks like this problem only rears its head when the accessed attribute does not exist on the wrapped optimizer. For whatever reason, what would normally be a nice readable error like: AttributeError: 'Adam' object has no attribute 'lr'
becomes this ugly recursion problem instead.
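A likely explanation (a minimal sketch, not the actual autoclip source): Python only calls `__getattr__` when normal attribute lookup fails. If `self.optimizer` is itself missing from the instance `__dict__` (for example while pickle reconstructs the object without running `__init__`), then `self.optimizer` inside `__getattr__` re-enters `__getattr__`, forever. A common fix is to look the field up with `object.__getattribute__` and raise a plain `AttributeError` when it is absent:

```python
class BrokenWrapper:
    """Delegates unknown attributes to a wrapped optimizer -- unsafely."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def __getattr__(self, attr):
        # If 'optimizer' is not yet in self.__dict__, this line triggers
        # __getattr__('optimizer') again -> infinite recursion.
        return getattr(self.optimizer, attr)


class SafeWrapper:
    """Same delegation, but without re-triggering __getattr__."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def __getattr__(self, attr):
        try:
            # Bypass __getattr__ when fetching the wrapped optimizer.
            optimizer = object.__getattribute__(self, "optimizer")
        except AttributeError:
            raise AttributeError(attr)
        return getattr(optimizer, attr)


# Simulate unpickling: __new__ creates the instance but skips __init__,
# so neither wrapper has an 'optimizer' attribute yet.
broken = BrokenWrapper.__new__(BrokenWrapper)
try:
    broken.lr
except RecursionError:
    print("RecursionError")

safe = SafeWrapper.__new__(SafeWrapper)
try:
    safe.lr
except AttributeError:
    print("AttributeError")  # readable error instead of blown stack
```

Whether autoclip's `clipper.py` hits exactly this path during `torch.load` is an assumption, but the traceback (996 repeated frames inside `__getattr__`) matches this pattern.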
As an example of the saving and loading working correctly:
Python 3.9.7 (default, Sep 16 2021, 13:09:58)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> from autoclip.torch import QuantileClip
>>> model = torch.nn.Linear(5, 5)
>>> optimizer = torch.optim.Adam(model.parameters())
>>> optimizer = QuantileClip.as_optimizer(optimizer, 0.5)
>>> torch.save(optimizer, "test_file.pth")
>>> optimizer = torch.load("test_file.pth")
>>> print(optimizer.defaults)
{'lr': 0.001, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False}
>>>
For those who run into problems with torch.save
or pickle in the future and find themselves on this thread: it is generally recommended to use the state_dict
pattern for saving and checkpointing. See the README.md
for more info.
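For reference, a sketch of the state_dict pattern with a plain model and optimizer (the QuantileClip-specific API is not shown here; see autoclip's README for whether the clipper exposes its own state_dict):

```python
import torch

model = torch.nn.Linear(5, 5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

# Save only the state dicts, not the live Python objects.
torch.save(
    {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
    "checkpoint.pth",
)

# Rebuild fresh objects with the same construction arguments,
# then restore their state from the checkpoint.
model = torch.nn.Linear(5, 5)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
```

This avoids pickling the wrapper object entirely, so the `__getattr__` recursion above never comes into play.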
When saving with torch.save() and loading with torch.load(), the autoclip clipper hits the max recursion depth.