Open bytbox opened 3 years ago
I think this is probably related to the fact that we added a __getattr__ method to variables, which may cause issues when restoring from pickle.
Here is a Stack Overflow discussion about __getattr__ and pickle: https://stackoverflow.com/questions/49380224/how-to-make-classes-with-getattr-pickable
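The failure mode described in that discussion can be sketched with a small stand-alone class. Wrapper and try_roundtrip are hypothetical names for illustration only and are not objax code; the class just forwards unknown attributes to a wrapped value, the way the variables' __getattr__ is described as doing. On CPython, the round trip below should fail while loading, not while dumping.

```python
import pickle


class Wrapper:
    """Hypothetical stand-in for a variable class; not objax code."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # Normal attribute access; falls through to __getattr__ if '_value'
        # has not been set on the instance yet.
        return self._value

    def __getattr__(self, name):
        # Called only when normal lookup fails; forwards to the wrapped value.
        return getattr(self.value, name)


def try_roundtrip(obj):
    """Return 'ok' or the name of the exception raised by pickle.loads."""
    data = pickle.dumps(obj)  # dumping works
    try:
        pickle.loads(data)
    except Exception as exc:
        return type(exc).__name__
    return "ok"


print(try_roundtrip(Wrapper([1.0, 2.0])))
```

While loading, pickle creates the instance without calling __init__ and then looks up __setstate__ on it. That lookup reaches __getattr__ before _value exists, so value and __getattr__ keep calling each other.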
I think the fix would be to change how the value property works in variables:
    class TrainVar(BaseVar):
        @property
        def value(self) -> JaxArray:
            return self.__dict__['_value']  # instead of self._value
and possibly similar changes for TrainRef and StateVar.
Would you be able to try these changes to see if they work or not?
This doesn't seem to work:
Traceback (most recent call last):
  File "/home/scott/objax/tests/pickle.py", line 33, in test_on_linear
    lin_ = pickle.loads(pickled)
  File "/home/scott/objax/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/objax/objax/variable.py", line 180, in value
    return self.__dict__['_value']
KeyError: '_value'
And indeed, if I print out self.__dict__ on the line before the return in value(), it's just empty.
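The empty __dict__ can be observed in isolation with a small probe. ProbeVar and probed are hypothetical names, not objax code; the class only records which attribute names reach __getattr__ during pickle.loads and what the instance dictionary holds at that moment.

```python
import pickle

# Each entry is (attribute name, copy of the instance __dict__ at that moment).
probed = []


class ProbeVar:
    """Hypothetical probe class; not objax code."""

    def __init__(self, value):
        self._value = value

    def __getattr__(self, name):
        # self.__dict__ is resolved through the type, so this does not recurse.
        probed.append((name, dict(self.__dict__)))
        raise AttributeError(name)


data = pickle.dumps(ProbeVar([1.0]))
probed.clear()  # ignore the lookups made while dumping
restored = pickle.loads(data)

for name, state in probed:
    print(name, state)
```

On CPython the probe should record a lookup of __setstate__ against an empty dictionary, which is consistent with the KeyError in the traceback: the state is only restored after that lookup.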
However, I can fix the error by following that link a little more closely and changing BaseVar to throw an AttributeError when _value is not available. That passes all tests and appears to fix this error. (But I'm not super familiar with how objax internals work, so maybe it's obvious that that's the wrong thing to do. Let me know!)
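A minimal sketch of that idea, assuming a simplified stand-in class: GuardedVar is a hypothetical name, and the actual change to BaseVar in the pull request may look different.

```python
import pickle


class GuardedVar:
    """Hypothetical stand-in for BaseVar/TrainVar; not the actual objax code."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        return self.__dict__['_value']

    def __getattr__(self, name):
        # During unpickling the instance exists before its state is restored,
        # so '_value' may be missing. Report "no such attribute" in that case
        # instead of failing with KeyError or recursing.
        if '_value' not in self.__dict__:
            raise AttributeError(name)
        # Otherwise forward unknown attributes to the wrapped value.
        return getattr(self.value, name)


original = GuardedVar([1.0, 2.0, 3.0])
restored = pickle.loads(pickle.dumps(original))

print(restored.value)
print(restored.count(2.0))  # forwarded to list.count on the wrapped value
```

Raising AttributeError lets pickle conclude that the class defines no __setstate__ and fall back to restoring the instance dictionary directly, while attribute forwarding keeps working once _value is present.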
A pull request follows...
This is with objax-1.4.0, from PyPI. The problem does not occur with objax-1.3.1.
For example:
The error is coming from loads, not dumps. The resulting stack trace ends with: