nikitaved opened 5 months ago
The root cause seems to be in vjp itself.
import thunder
import torch
def foo(x):
    return thunder.torch.topk(x, k=2)
x = torch.ones(3, 3) * 2
co_x = torch.ones(3, 3)
outputs = torch.topk(x, k=2)
# One cotangent per topk output (values, indices); `out` avoids shadowing x
cotangents = tuple(torch.ones_like(out) for out in outputs)
vjp_foo = thunder.core.transforms.vjp(foo)
jfoo = thunder.compile(vjp_foo, disable_preprocessing=True)
# jfoo = thunder.jit(vjp_foo) # Doesn't work currently.
# Fails with
# File "/home/kkalambarkar/lightning-thunder/thunder/core/utils.py", line 1062, in <lambda>
# if arg_name not in map(lambda x: x.name, stop_proxies) and arg_name not in seen:
# AttributeError: 'NoneType' object has no attribute 'name'
jfoo(primals=(x,), cotangents=cotangents)
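For verification, the result can be compared against PyTorch's eager autograd on the same primal. The following is a minimal sketch that uses only PyTorch, not Thunder. It assumes that only the values output of topk is differentiable, which is how eager PyTorch treats it, so the cotangent for the integer indices output has no effect here. Because every entry of x is equal, which positions topk selects is an implementation detail. The expected gradient is therefore built from the returned indices rather than hard-coded.

```python
import torch

# Same primal as the repro: every entry equals 2, so topk has ties.
x = torch.ones(3, 3) * 2
x.requires_grad_(True)

# topk returns (values, indices); only `values` carries gradients in eager PyTorch.
values, indices = torch.topk(x, k=2)

# Pull back a cotangent of ones for `values` (the indices cotangent is ignored).
(grad_x,) = torch.autograd.grad(values, x, grad_outputs=torch.ones_like(values))

# The VJP of topk scatters the cotangent back to the selected positions,
# so the reference gradient is zeros with ones at `indices` along the last dim.
expected = torch.zeros_like(x).scatter(-1, indices, torch.ones_like(values))

print(grad_x)
print(torch.equal(grad_x, expected))
```

Each row of the reference gradient should contain exactly k = 2 ones. The output of jfoo could be checked against it once thunder.jit(vjp(fn)) works.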
NOTE: Currently the test uses make_callable_legacy (which uses thunder.compile). We should probably wait till thunder.jit(vjp(fn)) is supported and then verify. (Related issue: https://github.com/Lightning-AI/lightning-thunder/issues/198)
🐛 Bug
As per title. To reproduce, one could uncomment these tests in https://github.com/Lightning-AI/lightning-thunder/pull/118 to get: