Open albertz opened 1 month ago
I got this now a second time (CI log). It occurs in about 10% of the cases (very approximately).
I assume the Tensor.__del__ handler maybe runs very late and calls into the PyTorch API, which is not really expected anymore at that point.
I just pushed something which should check for this. So let's see if this occurs again.
It also happened afterwards (in 6e2ce015f209036, CI log). Actually much more often now, seems to be 100% of the cases?
I can reproduce the crash locally.
(gdb) bt
#0 0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#1 0x00007ffff7c3baa7 in _PyObject_VisitInstanceAttributes () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#2 0x00007ffff7c4687c in subtype_traverse () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#3 0x00007ffff7cf9d15 in deduce_unreachable () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#4 0x00007ffff7cf9c4c in gc_collect_main () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#5 0x00007ffff7cf925c in gc_collect_with_callback () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#6 0x00007ffff7cf9ef2 in PyGC_Collect () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#7 0x00007ffff7cee923 in Py_FinalizeEx () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#8 0x00007ffff7cf8d40 in Py_RunMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#9 0x00007ffff7cf8ab9 in Py_BytesMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#10 0x00007ffff778f1b7 in __libc_start_call_main (main=main@entry=0x401040 <main>, argc=argc@entry=4, argv=argv@entry=0x7fffffffda28) at ../sysdeps/nptl/libc_start_call_main.h:58
#11 0x00007ffff778f26c in __libc_start_main_impl (main=0x401040 <main>, argc=4, argv=0x7fffffffda28, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>,
stack_end=0x7fffffffda18) at ../csu/libc-start.c:392
#12 0x0000000000401071 in _start () at ../sysdeps/x86_64/start.S:115
I was playing around with iterating through all alive objects at the end, and that also triggers the crash.
Something like this:
print("**** remaining objects:")
import gc
for obj in gc.get_objects():
    if type(obj) in {tuple, list, dict}:
        continue
    print(type(obj), obj)
Crash:
0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
(gdb) bt
#0 0x00007ffff7c86054 in visit_decref () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#1 0x00007ffff7c3baa7 in _PyObject_VisitInstanceAttributes () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#2 0x00007ffff7c4687c in subtype_traverse () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#3 0x00007ffff7cf9d15 in deduce_unreachable () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#4 0x00007ffff7cf9457 in gc_collect_main () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#5 0x00007ffff7cf925c in gc_collect_with_callback () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#6 0x00007ffff7c85ddf in _PyObject_GC_Link () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#7 0x00007ffff7c85d12 in _PyObject_GC_New () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#8 0x00007ffff7c43771 in tuple_iter () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#9 0x00007ffff7c194e9 in PyObject_GetIter () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#10 0x00007ffff7c63c5c in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#11 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#12 0x00007ffff7c6584a in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#13 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#14 0x00007ffff7c1fbba in PyObject_CallOneArg () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#15 0x00007ffff7c3e3e5 in _PyObject_GenericGetAttrWithDict () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#16 0x00007ffff7c3dcae in PyObject_GetAttr () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#17 0x00007ffff7c62947 in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#18 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#19 0x00007ffff7c46fe0 in vectorcall_method () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#20 0x00007ffff7cc0fb6 in slot_tp_str () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#21 0x00007ffff7c3eac7 in PyObject_Str () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#22 0x00007ffff7cad4c4 in PyFile_WriteObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#23 0x00007ffff7cdabaf in builtin_print () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#24 0x00007ffff7c6502f in _PyEval_EvalFrameDefault () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#25 0x00007ffff7c60cf2 in _PyEval_Vector () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#26 0x00007ffff7cdbec6 in PyEval_EvalCode () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#27 0x00007ffff7cf0884 in run_eval_code_obj () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#28 0x00007ffff7cf0806 in run_mod () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#29 0x00007ffff7cf0ff1 in pyrun_file () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#30 0x00007ffff7cf0c6b in _PyRun_SimpleFileObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#31 0x00007ffff7cf0a83 in _PyRun_AnyFileObject () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#32 0x00007ffff7cf8e8c in Py_RunMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
#33 0x00007ffff7cf8ab9 in Py_BytesMain () from /work/tools/users/zeyer/linuxbrew/lib/libpython3.11.so.1.0
With python3-dbg, I get some more detail:
Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b03ac0>) at ../Include/internal/pycore_object.h:166
166 ../Include/internal/pycore_object.h: No such file or directory.
(gdb) bt
#0 0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b03ac0>) at ../Include/internal/pycore_object.h:166
#1 visit_decref (parent=<optimized out>, op=<unknown at remote 0x7fff08b03ac0>) at ../Modules/gcmodule.c:456
#2 dict_traverse (op={'pack_hook': <unknown at remote 0x7fff08b03ac0>, 'unpack_hook': <function at remote 0x7fff0aa0c670>},
visit=<optimized out>, arg=<optimized out>) at ../Objects/dictobject.c:3250
#3 0x000055555567f3e5 in subtract_refs (containers=<optimized out>) at ../Modules/gcmodule.c:482
#4 deduce_unreachable (base=base@entry=0x555555b3dfa0, unreachable=unreachable@entry=0x7fffffffd410) at ../Modules/gcmodule.c:1105
#5 0x000055555567e94c in gc_collect_main (tstate=0x555555b59b90, generation=2, n_collected=0x7fffffffd4d8,
n_uncollectable=0x7fffffffd4d0, nofail=0) at ../Modules/gcmodule.c:1239
#6 0x0000555555785e80 in gc_collect_with_callback (tstate=0x555555b59b90, generation=2) at ../Modules/gcmodule.c:1413
#7 0x00005555557b74de in PyGC_Collect () at ../Modules/gcmodule.c:2099
#8 0x00005555557b4ef0 in Py_FinalizeEx () at ../Python/pylifecycle.c:1781
#9 0x00005555557a6313 in Py_RunMain () at ../Modules/main.c:668
#10 0x000055555577ca3d in Py_BytesMain (argc=<optimized out>, argv=<optimized out>) at ../Modules/main.c:720
#11 0x00007ffff7c29d90 in __libc_start_call_main (main=main@entry=0x55555577ca00 <main>, argc=argc@entry=2,
argv=argv@entry=0x7fffffffd7b8) at ../sysdeps/nptl/libc_start_call_main.h:58
#12 0x00007ffff7c29e40 in __libc_start_main_impl (main=0x55555577ca00 <main>, argc=2, argv=0x7fffffffd7b8, init=<optimized out>,
fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7fffffffd7a8) at ../csu/libc-start.c:392
#13 0x000055555577c935 in _start ()
With:
print("**** remaining objects:")
import gc
for obj in gc.get_objects():
    print("0x%x" % id(obj), type(obj), obj)
print("**** done.")
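When printing an object can itself crash, a more defensive way to inspect what is still alive is to look only at types and never call `str()` or `repr()` on the objects. The following is a minimal sketch of that idea; `summarize_live_objects` is a hypothetical helper name. It would not avoid a crash that happens inside the garbage collector itself.

```python
import gc
from collections import Counter


def summarize_live_objects(limit=10):
    """Return the `limit` most common type names among GC-tracked objects."""
    counts = Counter()
    for obj in gc.get_objects():
        # Only look at the type; do not call str()/repr() on the object itself.
        t = type(obj)
        counts[f"{t.__module__}.{t.__qualname__}"] += 1
    return counts.most_common(limit)


for name, count in summarize_live_objects():
    print(f"{count:8d}  {name}")
```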
Another variant of the crash:
Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08be3d80>) at ../Objects/object.c:422
422 ../Objects/object.c: No such file or directory.
(gdb) bt
#0 0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08be3d80>) at ../Objects/object.c:422
#1 0x00005555557ca2c0 in dict_repr (mp=0x7fff08d92300) at ../Objects/dictobject.c:2148
#2 0x00005555556c4f4d in object_str (
self={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
at ../Objects/typeobject.c:4550
#3 PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
at ../Objects/object.c:499
#4 PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>})
at ../Objects/object.c:462
#5 0x000055555578f24d in PyFile_WriteObject (
v={'pack_hook': <unknown at remote 0x7fff08be3d80>, 'unpack_hook': <function at remote 0x7fff0a9e8f70>}, f=<optimized out>,
flags=<optimized out>) at ../Objects/fileobject.c:132
#6 0x000055555578e8f2 in builtin_print (self=<optimized out>, args=0x7ffff7529db0, nargs=3, kwnames=<optimized out>)
at ../Python/bltinmodule.c:2003
#7 0x00005555556a22eb in cfunction_vectorcall_FASTCALL_KEYWORDS (
func=<built-in method print of module object at remote 0x7ffff7594950>, args=0x7ffff7529db0, nargsf=<optimized out>, kwnames=0x0)
at ../Objects/methodobject.c:446
#8 0x0000555555697827 in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff7529db0,
callable=<built-in method print of module object at remote 0x7ffff7594950>, tstate=0x555555b59b90)
at ../Include/cpython/abstract.h:114
#9 PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7ffff7529db0,
callable=<built-in method print of module object at remote 0x7ffff7594950>) at ../Include/cpython/abstract.h:123
#10 call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, trace_info=0x7fffffffd300,
tstate=<optimized out>) at ../Python/ceval.c:5893
#11 _PyEval_EvalFrameDefault (tstate=<optimized out>, f=<optimized out>, throwflag=<optimized out>) at ../Python/ceval.c:4213
#12 0x0000555555693f96 in _PyEval_EvalFrame (throwflag=0,
f=Frame 0x7ffff7529c40, for file /home/az/Programmierung/returnn/tests/test_torch_util.py, line 354, in <module> (),
tstate=0x555555b59b90) at ../Include/internal/pycore_ceval.h:46
#13 _PyEval_Vector (tstate=0x555555b59b90, con=<optimized out>, locals=<optimized out>, args=<optimized out>,
argcount=<optimized out>, kwnames=<optimized out>) at ../Python/ceval.c:5067
#14 0x0000555555789c66 in PyEval_EvalCode (co=<code at remote 0x7ffff74a6d90>,
globals={'__name__': '__main__', '__doc__': '\nTest :mod:`returnn.torch.util`.\n', '__package__': None, '__loader__': <SourceFileLoader(name='__main__', path='/home/az/Programmierung/returnn/tests/test_torch_util.py') at remote 0x7ffff73dd9f0>, '__spec__': None, '__annotations__': {}, '__builtins__': <module at remote 0x7ffff7594950>, '__file__': '/home/az/Programmierung/returnn/tests/test_torch_util.py', '__cached__': None, 'annotations': <_Feature(optional=(3, 7, 0, 'beta', 1), mandatory=(3, 11, 0, 'alpha', 0), compiler_flag=16777216) at remote 0x7ffff731ddb0>, '_setup_test_env': <module at remote 0x7ffff7315f80>, 'os': <module at remote 0x7ffff7404e00>, 'sys': <module at remote 0x7ffff7582390>, 'unittest': <module at remote 0x7fffb95e19e0>, 'torch': <module at remote 0x7fffb95e1530>, 'better_exchook': <module at remote 0x7ffff739eca0>, 'gradient_checkpoint_scope': <type at remote 0x5555562d9ee0>, 'test_gradient_checkpoint_scope': <function at remote 0x7ffff74b56c0>, 'test_gradient_checkpoint_scope_twice': <functio...(truncated),
locals=<optimized out>) at ../Python/ceval.c:1134
Note, this object you see here in object_str, that looks very much like the __dict__ of a saved_tensors_hooks instance, which has pack_hook and unpack_hook.
Ok, I added this print in gradient_checkpoint_scope.__init__ to print the address of the pack_hook method:
def __init__(self):
    self.record_graph_scope = _RecordGraph()
    self.record_graph_scope.graph.gradient_checkpoint_scope_backref = self
    # Note: saved_tensors_hooks is thread local.
    self.saved_tensors_hooks_scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
    print("*** pack hook: 0x%x" % id(self.saved_tensors_hooks_scope.pack_hook))
Then I get this at the end:
Executing: test_gradient_checkpoint_scope_twice
*** pack hook: 0x7fff0935a080
*** pack hook: 0x7fff08c07c80
*** _pack_hook
*** _pack_hook
*** _unpack_hook
*** _unpack_hook
*** _custom_saved_tensors_hooks_exit
*** pack hook: 0x7fff08be1240
*** _custom_saved_tensors_hooks_exit
*** pack hook: 0x7fff08be5200
*** _pack_hook
*** _pack_hook
*** _unpack_hook
*** _unpack_hook
*** _custom_saved_tensors_hooks_exit
----------------------------------------
Finished all tests.
**** remaining objects:
...
0x7fff08bc4e80 <class 'torch.autograd.graph.saved_tensors_hooks'> <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bc4e80>
Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08c07c80>) at ../Objects/object.c:422
422 ../Objects/object.c: No such file or directory.
(gdb) bt
#0 0x00005555556c0ff6 in PyObject_Repr (v=<unknown at remote 0x7fff08c07c80>) at ../Objects/object.c:422
#1 0x00005555557ca2c0 in dict_repr (mp=0x7fff08c07240) at ../Objects/dictobject.c:2148
#2 0x00005555556c4f4d in object_str (
self={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
at ../Objects/typeobject.c:4550
#3 PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
at ../Objects/object.c:499
#4 PyObject_Str (v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>})
at ../Objects/object.c:462
#5 0x000055555578f24d in PyFile_WriteObject (
v={'pack_hook': <unknown at remote 0x7fff08c07c80>, 'unpack_hook': <function at remote 0x7fff0a9e1000>}, f=<optimized out>,
flags=<optimized out>) at ../Objects/fileobject.c:132
#6 0x000055555578e8f2 in builtin_print (self=<optimized out>, args=0x7ffff7529db0, nargs=3, kwnames=<optimized out>)
at ../Python/bltinmodule.c:2003
...
So I guess we have already freed the method but we are still trying to access it here.
Could this be an error in PyTorch regarding refcounting on the pack_hook?
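For reference, this is how the lifetime of a stored bound method behaves in plain CPython when the refcounts are correct. It is only a sketch with hypothetical stand-in classes (`_Hooks`, `_Owner`) and says nothing about how the native side of PyTorch holds the hooks.

```python
import gc
import weakref


class _Hooks:
    """Hypothetical stand-in for an object that stores two callables."""

    def __init__(self, pack_hook, unpack_hook):
        self.pack_hook = pack_hook
        self.unpack_hook = unpack_hook


class _Owner:
    def pack(self, x):
        return x

    def unpack(self, x):
        return x


owner = _Owner()
hooks = _Hooks(owner.pack, owner.unpack)

# Watch the stored bound method object without keeping it alive ourselves.
ref = weakref.ref(hooks.pack_hook)
alive_while_held = ref() is not None

# Dropping the only holder releases the bound method as well.
del hooks
gc.collect()
alive_after_release = ref() is not None

print(alive_while_held, alive_after_release)
```

If some native code kept a raw pointer to the method without owning a reference, the object could be freed in exactly this way while still being reachable from the `__dict__` seen in the backtrace.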
Added some debug code:
def _custom_saved_tensors_hooks_exit(
    self: torch.autograd.graph.saved_tensors_hooks, exc_type=None, exc_val=None, exc_tb=None
):
    print(f"*** _custom_saved_tensors_hooks_exit, stack {_custom_saved_tensors_hooks_tls_ctx.stack}")
    f = sys._getframe()
    while f:
        co = f.f_code
        print("-", co.co_name, co.co_filename, f.f_lineno)
        f = f.f_back
    ...
Then:
**** iter 0
*** pack hook: 0x7fff091c2a00
*** gradient_checkpoint_scope.__enter__
*** pack hook: 0x7fff08b3c640
*** gradient_checkpoint_scope.__enter__
*** _custom_saved_tensors_hooks_enter
*** _pack_hook
*** _pack_hook
[New Thread 0x7fff08aff640 (LWP 232551)]
[New Thread 0x7ffefcfde640 (LWP 232555)]
*** _unpack_hook
*** _unpack_hook
*** exit_saved_tensors_hooks_scope __exit__ now, pack_hook: 0x7fff08b3c640
*** _custom_saved_tensors_hooks_exit, stack [<torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdecb0>, <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdf190>]
- _custom_saved_tensors_hooks_exit /home/az/Programmierung/returnn/tests/test_torch_util.py 641
- exit_saved_tensors_hooks_scope /home/az/Programmierung/returnn/tests/test_torch_util.py 296
- _maybe_exit_saved_tensors_hooks_scope /home/az/Programmierung/returnn/tests/test_torch_util.py 270
- _unpack_hook /home/az/Programmierung/returnn/tests/test_torch_util.py 315
- backward /home/az/.local/lib/python3.10/site-packages/torch/autograd/__init__.py 266
- backward /home/az/.local/lib/python3.10/site-packages/torch/_tensor.py 522
- demo_run /home/az/Programmierung/returnn/tests/test_torch_util.py 722
- test_saved_tensors_hooks_gc_segfault /home/az/Programmierung/returnn/tests/test_torch_util.py 730
- <module> /home/az/Programmierung/returnn/tests/test_torch_util.py 880
*** _custom_saved_tensors_hooks_exit: exit now, scope <torch.autograd.graph.saved_tensors_hooks object at 0x7fff08bdf190>, pack_hook 0x7fff08b3c640
...
**** iter 4
Thread 1 "python3.10" received signal SIGSEGV, Segmentation fault.
0x000055555567fb51 in _PyObject_IS_GC (obj=<unknown at remote 0x7fff08b3c640>) at ../Include/internal/pycore_object.h:166
166 ../Include/internal/pycore_object.h: No such file or directory.
So, maybe the problem is that we call saved_tensors_hooks.__exit__ inside the unpack hook?
I have a standalone test case:
import torch


def test_saved_tensors_hooks_gc_segfault2():
    # https://github.com/rwth-i6/returnn/issues/1581
    shape = (101, 103)
    for i in range(10):
        v1 = torch.nn.Parameter(torch.randn(shape))
        v2 = torch.nn.Parameter(torch.randn(shape))

        class _Handler:
            def __init__(self, exit_in_unpack: bool = False):
                self.scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
                self.exit_in_unpack = exit_in_unpack
                self.exited = False

            def _pack_hook(self, x):
                print(f"*** _pack_hook {self}")
                return self, x

            @staticmethod
            def _unpack_hook(x):
                self, x = x
                print(f"*** _unpack_hook {self}")
                if self.exit_in_unpack and not self.exited:
                    self.exited = True
                    self.scope.__exit__()
                return x

        handler1 = _Handler(exit_in_unpack=False)
        handler1.scope.__enter__()
        v1_ = v1 + torch.randn(shape)

        handler2 = _Handler(exit_in_unpack=True)
        handler2.scope.__enter__()
        v2_ = v2 + torch.randn(shape)

        x = v1_ * v2_
        x.sum().backward()
        del x
        handler1.scope.__exit__()
I'm trying to simplify this now further.
Slightly different version:
import torch


def test_saved_tensors_hooks_gc_segfault2():
    # https://github.com/rwth-i6/returnn/issues/1581
    shape = (101, 103)
    for i in range(10):
        print("**** iter", i)
        v = torch.nn.Parameter(torch.randn(shape))

        class _Handler:
            def __init__(self):
                self.scope = torch.autograd.graph.saved_tensors_hooks(self._pack_hook, self._unpack_hook)
                self.scope.__enter__()
                self.exited = False

            def _pack_hook(self, x):
                print(f"*** _pack_hook {self}")
                return x

            def _unpack_hook(self, x):
                print(f"*** _unpack_hook {self}")
                if not self.exited:
                    self.exited = True
                    self.scope.__exit__()
                return x

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            handler = _Handler()  # keep ref...  # noqa
            x = v * torch.randn(shape)
            x.sum().backward()
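Since the crash takes down the whole interpreter, it can help to run such a reproduction in a child process and check how it terminated. The following is a minimal sketch under that assumption; `run_isolated` is a hypothetical helper, and the signal check relies on POSIX semantics, where a negative return code means the child was killed by that signal.

```python
import signal
import subprocess
import sys


def run_isolated(code):
    """Run `code` in a fresh interpreter; return (returncode, crashed_with_segv)."""
    proc = subprocess.run(
        # faulthandler makes the child dump a Python traceback on SIGSEGV.
        [sys.executable, "-X", "faulthandler", "-c", code],
        capture_output=True,
        text=True,
    )
    # On POSIX, a negative return code means the child was killed by that signal.
    crashed = proc.returncode == -signal.SIGSEGV
    return proc.returncode, crashed


returncode, crashed = run_isolated("print('ok')")
print(returncode, crashed)
```

The string passed in could be the source of the test above followed by a call to it, so that a segmentation fault shows up as a result instead of killing the test runner.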
I reported that upstream: https://github.com/pytorch/pytorch/issues/130734
I pushed a workaround now. See _can_exit_saved_tensors_hooks_inside_hooks. If possible, I would like to extend this logic later. But let's wait for the response in https://github.com/pytorch/pytorch/issues/130734.
Actually, let's keep this open until we get some response, and then wait until we can update _can_exit_saved_tensors_hooks_inside_hooks.
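The general idea of not exiting a scope while one of its hooks is running can be sketched as follows. This is not the actual workaround code: `_ScopeExitGuard`, `call_hook`, `maybe_exit` and `_DummyScope` are hypothetical names, and the wrapped scope is assumed to be any context manager.

```python
class _ScopeExitGuard:
    """Hypothetical wrapper that refuses to exit its scope from inside a hook."""

    def __init__(self, scope, can_exit_inside_hooks=False):
        self.scope = scope  # any context manager
        self.can_exit_inside_hooks = can_exit_inside_hooks
        self._hook_depth = 0  # > 0 while one of our hooks is running
        self.exit_pending = False
        self.exited = False

    def call_hook(self, fn, *args):
        """Run `fn` as a hook body while tracking that we are inside a hook."""
        self._hook_depth += 1
        try:
            return fn(*args)
        finally:
            self._hook_depth -= 1

    def maybe_exit(self):
        """Exit the scope if allowed right now, else remember the request.

        Returns True if the scope is exited after this call.
        """
        if self.exited:
            return True
        if self._hook_depth > 0 and not self.can_exit_inside_hooks:
            self.exit_pending = True
            return False
        self.exited = True
        self.exit_pending = False
        self.scope.__exit__(None, None, None)
        return True


class _DummyScope:
    """Minimal context manager that counts how often it was exited."""

    def __init__(self):
        self.exits = 0

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        self.exits += 1


dummy = _DummyScope()
guard = _ScopeExitGuard(dummy)
inside = guard.call_hook(guard.maybe_exit)  # requested inside a hook: deferred
outside = guard.maybe_exit()  # requested again outside any hook: performed
print(inside, outside, dummy.exits)
```

A pending request has to be picked up later by some call that happens outside of any hook.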
Also note, the current solution is maybe not optimal. The current potential ways that we would exit the torch.autograd.graph.saved_tensors_hooks:
- gradient_checkpoint_scope.__exit__. But likely not, as there are likely refs to the registered tensors.
- gradient_checkpoint_scope.__del__ if in the right thread. But likely not, as there are likely still refs to the registered tensors.
- Tensor.__del__ if in the right thread.
- torch.autograd.graph.saved_tensors_hooks.__enter__ or torch.autograd.graph.saved_tensors_hooks.__exit__.
- _can_exit_saved_tensors_hooks_inside_hooks.
So this means that, in practice, with the current _can_exit_saved_tensors_hooks_inside_hooks check, the only realistic way that it gets cleaned up is via the next saved_tensors_hooks.__enter__ or Tensor.__del__. Tensor.__del__ would be fine, but we cannot guarantee that this will be in the right thread.
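The "if in the right thread" condition can be sketched like this: remember the thread that entered the scope and only exit when called from that same thread. `_ThreadBoundScope` is a hypothetical name, and `contextlib.nullcontext()` merely stands in for a thread-local scope such as saved_tensors_hooks.

```python
import contextlib
import threading


class _ThreadBoundScope:
    """Hypothetical wrapper that only exits its scope on the thread that entered it."""

    def __init__(self, scope):
        self.scope = scope
        self._owner_thread = None
        self.exited = False

    def enter(self):
        self._owner_thread = threading.get_ident()
        self.scope.__enter__()

    def exit_if_owner_thread(self):
        """Exit and return True only on the entering thread and if not yet exited."""
        if self.exited or self._owner_thread != threading.get_ident():
            return False
        self.exited = True
        self.scope.__exit__(None, None, None)
        return True


scope = _ThreadBoundScope(contextlib.nullcontext())
scope.enter()

results = []
# A finalizer running on another thread must not exit the scope.
t = threading.Thread(target=lambda: results.append(scope.exit_if_owner_thread()))
t.start()
t.join()
# The same call on the entering thread does exit it.
results.append(scope.exit_if_owner_thread())
print(results)
```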
The _GraphTensor objects are cleaned up independently of that, so the only problem here is the additional overhead we get because of a few pack/unpack hooks which don't do anything. Some solution to https://github.com/pytorch/pytorch/issues/129867 would allow us to reduce this (and also simplify the whole logic).
I just saw this in the CI (at commit d5b954b8f6e4c84ec2c289733590e1bf4154ba8b):
So the tests ran through, but at exit we got a segmentation fault. Maybe the gradient scope was cleaned up at that late point?