anthony-wang/CrabNet

Predict materials properties using only the composition information!
https://doi.org/10.1038/s41524-021-00545-1
MIT License

AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'. Did you mean: '_optimizer_step_code'? #36

Open pbenner opened 1 year ago

pbenner commented 1 year ago

The PyTorch Optimizer class has changed in recent releases, which leads to the following error:

[...]
stepping every 16 training passes, cycling lr every 1 epochs
checkin at 2 epochs to match lr scheduler
Traceback (most recent call last):
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 141, in <module>
    run_cv(X, y, f'eval-{task}-{target}.txt', n_splits)
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 95, in run_cv
    model = train_model()
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/eval.py", line 62, in train_model
    model.fit(epochs=1000, losscurve=False)
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 228, in fit
    self.train()
  File "/home/pbenner/Source/pycoordinationnet-results/model_comparison/crabnet/model.py", line 140, in train
    self.optimizer.step()
  File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/lr_scheduler.py", line 69, in wrapper
    return wrapped(*args, **kwargs)
  File "/home/pbenner/.local/opt/anaconda3/envs/crysfeat/lib/python3.10/site-packages/torch/optim/optimizer.py", line 271, in wrapper
    for pre_hook in chain(_global_optimizer_pre_hooks.values(), self._optimizer_step_pre_hooks.values()):
AttributeError: 'SWA' object has no attribute '_optimizer_step_pre_hooks'. Did you mean: '_optimizer_step_code'?
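
For context: as far as I can tell, PyTorch 2.0 and later wrap every Optimizer subclass's step() in a hook mechanism that iterates self._optimizer_step_pre_hooks, and that dict is normally created in Optimizer.__init__. SWA sets up its own state without ever calling Optimizer.__init__, so the attribute never exists. Below is a minimal sketch of the failure mode; the wrapper class is hypothetical and only mimics how SWA wraps another optimizer.

# Minimal sketch, not CrabNet code (assumes torch >= 2.0): an Optimizer
# subclass that, like SWA, never calls Optimizer.__init__ and therefore lacks
# the hook dicts that the step() wrapper in torch/optim/optimizer.py reads.
from collections import defaultdict
import torch
from torch.optim import Optimizer

class WrapperLikeSWA(Optimizer):
    def __init__(self, optimizer):
        # Same pattern as SWA.__init__: set up state manually instead of
        # calling super().__init__(), so _optimizer_step_pre_hooks is never created.
        self.optimizer = optimizer
        self.defaults = optimizer.defaults
        self.param_groups = optimizer.param_groups
        self.state = defaultdict(dict)

    def step(self, closure=None):
        return self.optimizer.step(closure)

p = torch.nn.Parameter(torch.zeros(1))
opt = WrapperLikeSWA(torch.optim.SGD([p], lr=0.1))
opt.step()  # raises the AttributeError above on torch >= 2.0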

The following patch fixed the issue:


diff --git a/utils/optim.py b/utils/optim.py
index 33008dd..18224ea 100644
--- a/utils/optim.py
+++ b/utils/optim.py
@@ -1,6 +1,7 @@
-from collections import defaultdict
+from collections import defaultdict, OrderedDict
 from itertools import chain
 from torch.optim import Optimizer
+from typing import Callable, Dict
 import torch
 import warnings
 import numpy as np
@@ -116,6 +117,8 @@ class SWA(Optimizer):
         self.optimizer = optimizer

         self.defaults = self.optimizer.defaults
+        self._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()
+        self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()
         self.param_groups = self.optimizer.param_groups
         self.state = defaultdict(dict)
         self.opt_state = self.optimizer.state
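
If editing the installed utils/optim.py is not an option, the same two dictionaries can be attached from the calling code instead. This is an untested sketch of the same idea as the patch above; it assumes SWA is importable as utils.optim.SWA (as in the repository layout), so adjust the import to your setup.

# Untested workaround sketch: add the missing hook dicts without editing
# utils/optim.py. Assumes SWA is importable from utils.optim; adjust the
# import path to wherever CrabNet's optim module lives in your installation.
from collections import OrderedDict
from utils.optim import SWA

_original_swa_init = SWA.__init__

def _swa_init_with_hooks(self, *args, **kwargs):
    _original_swa_init(self, *args, **kwargs)
    # Recent PyTorch expects these dicts on every Optimizer; they are normally
    # created in Optimizer.__init__, which SWA never calls.
    if not hasattr(self, "_optimizer_step_pre_hooks"):
        self._optimizer_step_pre_hooks = OrderedDict()
    if not hasattr(self, "_optimizer_step_post_hooks"):
        self._optimizer_step_post_hooks = OrderedDict()

SWA.__init__ = _swa_init_with_hooks

Run this once before constructing the model, and model.fit() should get past the step() wrapper the same way the patch does.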