Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

value_and_grad returns None gradients with thunder.jit #211

Open kshitij12345 opened 5 months ago

kshitij12345 commented 5 months ago

Note that the same call works with the deprecated thunder.compile.

import torch
from thunder.core.transforms import value_and_grad
import thunder

def model(x, w1):
    return x + w1

inp = torch.randn(1, 1)  # doesn't matter if requires_grad is True/False
w1 = torch.randn(1, 1)  # doesn't matter if requires_grad is True/False

print(thunder.compile(value_and_grad(model), disable_preprocessing=True)(inp, w1))
print(thunder.jit(value_and_grad(model))(inp, w1))

Output:

(tensor([[0.4349]]), (tensor([[1.]]), tensor([[1.]])))
(tensor([[0.4349]]), (None, None))

kshitij12345 commented 5 months ago

cc: @IvanYashchuk

kshitij12345 commented 5 months ago

Grabbing for investigation

kshitij12345 commented 5 months ago

In the function is_constant_for_vjp, the following line incorrectly returns True: https://github.com/Lightning-AI/lightning-thunder/blob/257876624bf72392e979cbf016442d25322a0201/thunder/core/transforms.py#L3338

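This happens because the isinstance lookaside in general_jit (_general_jit_isinstance_lookaside) handles a TensorProxy object by computing res = issubclass(torch.Tensor, ucls) instead of running a regular isinstance check. Since torch.Tensor is not a subclass of TensorProxy, an isinstance check against TensorProxy fails for a tensor proxy, and the line above incorrectly returns True. https://github.com/Lightning-AI/lightning-thunder/blob/257876624bf72392e979cbf016442d25322a0201/thunder/core/jit_ext.py#L708-L720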
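The mechanism can be reproduced without thunder or torch. The sketch below is only an illustration: Tensor and TensorProxy are empty stand-in classes, isinstance_lookaside copies the branch quoted above, and looks_constant is a hypothetical predicate written in the spirit of is_constant_for_vjp, not its actual body.

```python
class Tensor:
    """Stand-in for torch.Tensor (hypothetical, for illustration only)."""


class TensorProxy:
    """Stand-in for thunder's TensorProxy; deliberately not a Tensor subclass."""


def isinstance_lookaside(obj, cls):
    # Mirrors the quoted branch: a proxy is checked as if it were a Tensor.
    if isinstance(obj, TensorProxy):
        return issubclass(Tensor, cls)
    return isinstance(obj, cls)


def looks_constant(arg):
    # Hypothetical predicate (an assumption, not thunder's code):
    # an argument counts as constant when it is not a tensor proxy.
    return not isinstance_lookaside(arg, TensorProxy)


proxy = TensorProxy()
print(isinstance_lookaside(proxy, Tensor))       # True
print(isinstance_lookaside(proxy, TensorProxy))  # False
print(looks_constant(proxy))                     # True
```

Under these assumptions the proxy passes a Tensor check but fails a TensorProxy check, so any predicate that relies on the latter would treat it as a constant.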

A potential fix:

diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py
index 9ccce83..21f77b5 100644
--- a/thunder/core/jit_ext.py
+++ b/thunder/core/jit_ext.py
@@ -709,7 +709,7 @@ def _general_jit_isinstance_lookaside(obj: Any, cls: type | UnionType | tuple[ty
     uobj = unwrap(obj)
     ucls = unwrap(cls)
     if isinstance(uobj, TensorProxy):
-        res = issubclass(torch.Tensor, ucls)
+        res = issubclass(torch.Tensor, ucls) or isinstance(uobj, ucls)
     else:
         res = isinstance(uobj, ucls)
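The effect of the patched condition can be checked in isolation. This is a minimal sketch with stand-in classes, not thunder's real types: Tensor, TensorProxy and patched_isinstance_lookaside are hypothetical names, and only the two-part condition is taken from the diff.

```python
class Tensor:
    """Stand-in for torch.Tensor (hypothetical, for illustration only)."""


class TensorProxy:
    """Stand-in for thunder's TensorProxy; not a Tensor subclass."""


def patched_isinstance_lookaside(obj, cls):
    # Same shape as the patched branch: keep the Tensor-based check,
    # and additionally accept classes the proxy really is an instance of.
    if isinstance(obj, TensorProxy):
        return issubclass(Tensor, cls) or isinstance(obj, cls)
    return isinstance(obj, cls)


proxy = TensorProxy()
print(patched_isinstance_lookaside(proxy, TensorProxy))         # True
print(patched_isinstance_lookaside(proxy, Tensor))              # True
print(patched_isinstance_lookaside(proxy, (int, TensorProxy)))  # True
print(patched_isinstance_lookaside(proxy, int))                 # False
```

In this model the patched condition keeps the existing behaviour for Tensor checks and only adds the cases where the proxy is a real instance of the requested class.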

With this change, a different error is raised:

  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3662, in _vjp
    result, vjp_result = vjp_call(flat_args, cotangents, trace=trace)
  File "/home/kkalambarkar/lightning-thunder/thunder/core/interpreter.py", line 6179, in partial_call_impl
    return partial_function.func(*(partial_function.args + args), **(partial_function.keywords | kwargs))
  File "/home/kkalambarkar/lightning-thunder/thunder/core/transforms.py", line 3636, in vjp_call_metafunc
    result, env = augmented_forward_pass(*primals, trace=trace, **kwargs)
TypeError: expected 1 argument, got 2