google / objax

Apache License 2.0

RecursionError when attempting to unpickle objax objects #221

Open bytbox opened 3 years ago

bytbox commented 3 years ago

This is with objax-1.4.0, from PyPI. The problem does not occur with objax-1.3.1.

For example:

(env) theseus:~/mlqm/de$ python
Python 3.9.2 (default, Feb 20 2021, 18:40:11)
[GCC 10.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle, objax
>>> try:
...     pickle.loads(pickle.dumps(objax.nn.Linear(3,1)))
... except RecursionError:
...     print('Recursion error')
...
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Recursion error

The error is coming from loads, not dumps. The resulting stack trace ends with:

  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/mlqm/env/lib/python3.9/site-packages/objax/variable.py", line 179, in value
    return self._value
RecursionError: maximum recursion depth exceeded

AlexeyKurakin commented 3 years ago

I think it's probably related to the fact that we added a __getattr__ method to variables, which may cause issues when restoring from a pickle. Here is a Stack Overflow discussion about __getattr__ and pickle: https://stackoverflow.com/questions/49380224/how-to-make-classes-with-getattr-pickable
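For anyone following along, here is a minimal pure-Python sketch of the cycle in the traceback above. DelegatingVar is a hypothetical toy class, not objax's actual variable code. It only assumes the pattern visible in the trace: a value property that reads self._value, plus a __getattr__ that forwards unknown names to self.value. On unpickling, the instance is created without running __init__, so its __dict__ is empty, and pickle then looks up __setstate__ on it. That lookup falls through to __getattr__, which calls self.value, whose read of self._value falls through to __getattr__ again, and so on.

```python
import pickle


class DelegatingVar:
    """Hypothetical stand-in for an objax variable: the wrapped value lives
    in `_value`, and unknown attributes are forwarded to it."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        # If `_value` is missing from the instance __dict__, this read
        # falls through to __getattr__ below, which calls `self.value` again.
        return self._value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        return getattr(self.value, name)


def reproduce():
    data = pickle.dumps(DelegatingVar([1.0, 2.0]))  # dumping succeeds
    try:
        # Unpickling looks up __setstate__ on an instance with an empty __dict__.
        pickle.loads(data)
    except RecursionError:
        return "Recursion error"
    return "ok"


print(reproduce())  # Recursion error
```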

I think the fix would be to change how the value property works in variables:


class TrainVar(BaseVar):

    @property
    def value(self) -> JaxArray:
        return self.__dict__['_value']  # instead of self._value

and possibly similar changes for TrainRef and StateVar.

Would you be able to try these changes to see if they work or not?

bytbox commented 3 years ago

This doesn't seem to work:

Traceback (most recent call last):
  File "/home/scott/objax/tests/pickle.py", line 33, in test_on_linear
    lin_ = pickle.loads(pickled)
  File "/home/scott/objax/objax/variable.py", line 143, in __getattr__
    return getattr(self.value, name)
  File "/home/scott/objax/objax/variable.py", line 180, in value
    return self.__dict__['_value']
KeyError: '_value'

And indeed, if I print self.__dict__ on the line before the return in value(), it's empty.
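A sketch of why the suggested change turns the recursion into a KeyError, again using a hypothetical toy class (DictLookupVar) rather than objax itself. The value property now reads self.__dict__['_value'] directly, so there is no second trip through __getattr__. However, when pickle looks up __setstate__ on the freshly created instance, __dict__ is still empty, so the dictionary lookup raises KeyError, which pickle does not treat as a missing attribute.

```python
import pickle


class DictLookupVar:
    """Hypothetical toy class with the suggested change applied:
    `value` reads self.__dict__['_value'] instead of self._value."""

    seen_dicts = []  # snapshots of __dict__ taken inside `value`

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        DictLookupVar.seen_dicts.append(dict(self.__dict__))
        return self.__dict__['_value']  # KeyError if `_value` is not set yet

    def __getattr__(self, name):
        return getattr(self.value, name)


def try_roundtrip():
    data = pickle.dumps(DictLookupVar([1.0, 2.0]))
    DictLookupVar.seen_dicts.clear()  # only record what happens during loads
    try:
        pickle.loads(data)
    except KeyError as err:
        return err.args[0], DictLookupVar.seen_dicts
    return None, DictLookupVar.seen_dicts


key, seen = try_roundtrip()
print(key, seen)  # _value [{}]
```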

However, I can fix the error by following that link a little more closely and changing BaseVar to raise an AttributeError when _value is not available. That passes all tests and appears to fix this error. (But I'm not very familiar with objax internals, so maybe it's obvious that this is the wrong thing to do. Let me know!)
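A minimal sketch of that idea, under the same assumptions as before: SafeVar is a hypothetical toy class, and this is not the code from the pull request. Its __getattr__ raises AttributeError whenever _value has not been set yet, so pickle's __setstate__ lookup sees an ordinary missing attribute and falls back to restoring __dict__ directly. Delegation works again once the object has been restored.

```python
import pickle


class SafeVar:
    """Hypothetical toy class sketching the guard described above."""

    def __init__(self, value):
        self._value = value

    @property
    def value(self):
        return self._value

    def __getattr__(self, name):
        # During unpickling the instance starts with an empty __dict__.
        # Refuse to delegate until `_value` exists, raising AttributeError so
        # that attribute probes (such as pickle's lookup of __setstate__)
        # simply see "no such attribute" instead of recursing.
        if '_value' not in self.__dict__:
            raise AttributeError(name)
        return getattr(self.value, name)


restored = pickle.loads(pickle.dumps(SafeVar([1.0, 2.0])))
print(restored.value)       # [1.0, 2.0]
print(restored.count(2.0))  # 1, delegated to the underlying list
```

Raising AttributeError specifically matters here because getattr(obj, name, default) and hasattr() only swallow AttributeError; any other exception, including unbounded recursion, escapes from the lookup.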

A pull request follows...