kshitij12345 opened 5 months ago
cc: @IvanYashchuk
Grabbing for investigation
In function `is_constant_for_vjp`, the following line incorrectly returns `True`:
https://github.com/Lightning-AI/lightning-thunder/blob/257876624bf72392e979cbf016442d25322a0201/thunder/core/transforms.py#L3338
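This happens because the `isinstance` lookaside in `general_jit` replaces the check for a `TensorProxy` object with `res = issubclass(torch.Tensor, ucls)`. Since `torch.Tensor` is not a subclass of `TensorProxy`, a check such as `isinstance(p, TensorProxy)` on a proxy `p` evaluates `issubclass(torch.Tensor, TensorProxy)`, which is `False`. Hence the line above incorrectly returns `True`.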
https://github.com/Lightning-AI/lightning-thunder/blob/257876624bf72392e979cbf016442d25322a0201/thunder/core/jit_ext.py#L708-L720
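To make the mechanism concrete, here is a minimal, torch-free sketch. `Tensor` and `TensorProxy` are hypothetical stand-ins for `torch.Tensor` and thunder's `TensorProxy`; `isinstance_lookaside` only mirrors the `TensorProxy` branch linked above (it is not thunder's actual lookaside), and `all_args_non_differentiable` is a hypothetical check in the spirit of `is_constant_for_vjp`, not its real code.

```python
# Hypothetical stand-ins: Tensor plays the role of torch.Tensor and
# TensorProxy the role of thunder's TensorProxy. As in the issue, the
# proxy class is NOT a subclass of the tensor class.
class Tensor:
    pass


class TensorProxy:
    pass


def isinstance_lookaside(obj, cls):
    """Mirrors the pre-fix TensorProxy branch of the lookaside (sketch only)."""
    if isinstance(obj, TensorProxy):
        # A proxy is answered as if it were a real tensor.
        return issubclass(Tensor, cls)
    return isinstance(obj, cls)


def all_args_non_differentiable(args):
    """Hypothetical check: True when no argument looks like a TensorProxy."""
    return not any(isinstance_lookaside(a, TensorProxy) for a in args)


proxy = TensorProxy()
print(isinstance_lookaside(proxy, Tensor))       # True: passes as a tensor
print(isinstance_lookaside(proxy, TensorProxy))  # False: issubclass(Tensor, TensorProxy)
print(all_args_non_differentiable([proxy]))      # True, although proxy is a tensor input
```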
Here is a potential fix:
```diff
diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 9ccce83..21f77b5 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -709,7 +709,7 @@ def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[ty
     uobj = unwrap(obj)
     ucls = unwrap(cls)
     if isinstance(uobj, TensorProxy):
-        res = issubclass(torch.Tensor, ucls)
+        res = issubclass(torch.Tensor, ucls) or isinstance(uobj, ucls)
     else:
         res = isinstance(uobj, ucls)
```
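A similar sketch of the patched branch, again using hypothetical `Tensor`/`TensorProxy` stand-ins rather than the real torch and thunder types, shows that the added `or isinstance(uobj, ucls)` term restores checks against the proxy's own class while proxies are still accepted as tensors:

```python
# Hypothetical stand-ins for torch.Tensor and thunder's TensorProxy.
class Tensor:
    pass


class TensorProxy:
    pass


def isinstance_lookaside_patched(obj, cls):
    """Mirrors the patched TensorProxy branch from the diff (sketch only)."""
    if isinstance(obj, TensorProxy):
        # Keep treating the proxy as a tensor, but also accept checks
        # against classes the proxy really is an instance of.
        return issubclass(Tensor, cls) or isinstance(obj, cls)
    return isinstance(obj, cls)


proxy = TensorProxy()
print(isinstance_lookaside_patched(proxy, Tensor))               # True
print(isinstance_lookaside_patched(proxy, TensorProxy))          # True
print(isinstance_lookaside_patched(proxy, (int, TensorProxy)))   # True
print(isinstance_lookaside_patched(proxy, int))                  # False
```

The `or` only widens the result: anything that passed before still passes, and ordinary `isinstance` semantics are added on top. Tuples of classes work because both `issubclass` and `isinstance` accept them. This sketch does not model the follow-up error.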
However, this leads to a different error:
```
  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3662, in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6179, in partial_call_impl
    return partial_function.func(*(partial_function.args + args), **(partial_function.keywords | kwargs))
  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3636, in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
TypeError: expected 1 argument, got 2
```
Note that it works with the deprecated `thunder.compile`.
Output: