Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

CUDA Graphs region includes unnecessary "del" symbols resulting in an invalid Python program #728

Open: IvanYashchuk opened this issue 2 weeks ago

IvanYashchuk commented 2 weeks ago

🐛 Bug

First, apply the following patch:

diff --git a/thunder/executors/cudagraphex.py b/thunder/executors/cudagraphex.py
index 61e482e5..38bbf81f 100644
--- a/thunder/executors/cudagraphex.py
+++ b/thunder/executors/cudagraphex.py
@@ -155,6 +155,8 @@ class CUDAGraphExecutor(FusionExecutor):
         return fusion_bsym

     def can_fuse(self, bsym: BoundSymbol):
+        from thunder.executors.torchex import item
+
         curr_tracectx = get_tracectx()
         assert hasattr(curr_tracectx, "clear_collection_names")

@@ -180,6 +182,7 @@ class CUDAGraphExecutor(FusionExecutor):
             prims.PrimIDs.RETURN,
             # Data-dependent ops
             prims.PrimIDs.ITEM,
+            item.id,
         }

         if bsym.sym.id in do_not_fuse_sym_set:
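The patch adds torchex's `item` to the do-not-fuse set, so `Tensor.item` stays outside the CUDA graph region. The toy partitioner below illustrates why that matters for deletions. It is a simplified sketch, not Thunder's `can_fuse` logic: `Op`, `DO_NOT_FUSE`, `partition`, and `region_outputs` are hypothetical names, and ops are plain tuples of names rather than `BoundSymbol`s.

```python
from dataclasses import dataclass

@dataclass
class Op:
    name: str       # hypothetical op name, e.g. "sum" or "item"
    inputs: tuple   # names this op reads
    outputs: tuple  # names this op defines

# Hypothetical stand-in for do_not_fuse_sym_set: data-dependent ops stay eager.
DO_NOT_FUSE = {"item"}

def partition(ops):
    """Group maximal runs of fusible ops into regions; unfusible ops run alone."""
    groups, current = [], []
    for op in ops:
        if op.name in DO_NOT_FUSE:
            if current:
                groups.append(("region", current))
                current = []
            groups.append(("eager", [op]))
        else:
            current.append(op)
    if current:
        groups.append(("region", current))
    return groups

def region_outputs(region, all_ops, returned):
    """Names defined inside `region` that are read outside it or returned."""
    region_ids = {id(op) for op in region}
    defined = {n for op in region for n in op.outputs}
    outside_reads = {n for op in all_ops if id(op) not in region_ids for n in op.inputs}
    return defined & (outside_reads | set(returned))

# Mirrors the reproducer: t0 = x.sum(); f1 = t0.item(); return f1
ops = [Op("sum", ("x",), ("t0",)), Op("item", ("t0",), ("f1",))]
groups = partition(ops)
print([kind for kind, _ in groups])                        # ['region', 'eager']
print(region_outputs(groups[0][1], ops, returned=("f1",)))  # {'t0'}
```

Once `item` is kept out of the region, `t0` becomes a value the region must hand back to the eager code, so the region itself must not delete it.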

Then run the following reproducer:

import thunder
import torch

def func(x):
    return x.sum().item()

a = torch.randn(3, 3, device="cuda")

jfunc = thunder.jit(func, use_cudagraphs=True)
jfunc(a)  # Fails with UnboundLocalError: local variable 't0' referenced before assignment

File thunder.computation_4:10, in computation(x)
      6 @torch.no_grad()
      7 @no_autocast
      8 def computation(x):
      9   # x: "cuda:0 f32[3, 3]"
---> 10   [t0] = CUDAGraph0(x)
     11   f1 = Tensor.item(t0)  # f1: "float ?"
     12   return f1

File ~/dev/lightning-thunder/thunder/executors/cudagraphex.py:93, in CUDAGraphCallable.__call__(self, *args)
     89     static_inputs_mask = tuple(isinstance(arg, torch.nn.Parameter) for arg in args)
     91 args_descriptor = to_arg_descriptor(*args)
---> 93 graph, static_inputs, static_outputs = build_cuda_graph(self.fn, args_descriptor, static_inputs_mask)
     95 for static_input, arg in utils.safe_zip(static_inputs, args):
     96     if id(static_input) != id(arg) and isinstance(static_input, torch.Tensor) and isinstance(arg, torch.Tensor):

File ~/dev/lightning-thunder/thunder/executors/cudagraphex.py:62, in build_cuda_graph(fn, args_descriptor, static_args_mask)
     58     static_inputs = tuple(
     59         get_static_buffer(arg) if not is_static else arg for arg, is_static in zip(args, static_args_mask)
     60     )
     61     for _ in range(3):
---> 62         fn(*static_inputs)
     64 stream.synchronize()
     65 torch.cuda.current_stream().wait_stream(stream)

File ~/dev/pytorch/main/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_autocast(*args, **kwargs):
     27     with autocast_instance:
---> 28         return func(*args, **kwargs)

File ~/dev/pytorch/main/torch/amp/autocast_mode.py:28, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     25 @functools.wraps(func)
     26 def decorate_autocast(*args, **kwargs):
     27     with autocast_instance:
---> 28         return func(*args, **kwargs)

File thunder.CUDAGraph0_fn_3:10, in CUDAGraph0_fn(***failed resolving arguments***)
      8 del x
      9 del t0
---> 10 return [t0]

UnboundLocalError: local variable 't0' referenced before assignment
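The last frame shows the root cause: the generated `CUDAGraph0_fn` runs `del t0` and then `return [t0]`. The same statement order fails identically in ordinary Python. The functions below are hand-written stand-ins for the generated code (the `x * 2` body is an assumption; the trace excerpt does not show how `t0` is computed).

```python
def region_fn(x):
    # Same shape as CUDAGraph0_fn: t0 is computed, deleted, then returned.
    t0 = x * 2
    del x
    del t0
    return [t0]  # raises UnboundLocalError: t0 was unbound by the del above

def region_fn_fixed(x):
    # Without the spurious `del t0`, the region is a valid program.
    t0 = x * 2
    del x
    return [t0]

try:
    region_fn(1)
except UnboundLocalError as e:
    print(type(e).__name__)  # UnboundLocalError

print(region_fn_fixed(1))  # [2]
```

The exact error message varies across Python versions, but the exception type is the same.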

A minimal example was requested here: https://github.com/Lightning-AI/lightning-thunder/pull/214#discussion_r1665520398.

tfogal commented 2 weeks ago

triage review:

nikitaved commented 2 weeks ago

@tfogal, which ops other than data-dependent ones are we talking about?

I left a comment about the need to investigate how to handle the islands around data-dependent ops since, until now, we did not have proper use cases for that. We can delete tensors; we just should not delete the ones coming from non-fused regions (aka data-dependent ops).
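A minimal sketch of that rule, assuming a toy statement list instead of Thunder's trace objects: a `del` inside a region is kept only if the name is neither a region output nor read by a later statement. `prune_dels` and `keep_alive` are hypothetical names, not Thunder APIs.

```python
def prune_dels(stmts, keep_alive):
    """Drop deletions of names that are still needed.

    stmts: list of (kind, names) pairs; kind is "compute", "del", or "return".
      For "compute" and "return", names are the names the statement reads;
      for "del", names are the names it deletes.
    keep_alive: names that must survive the region (its outputs, e.g. values
      consumed by a non-fused data-dependent op such as item).
    """
    pruned = []
    for i, (kind, names) in enumerate(stmts):
        if kind != "del":
            pruned.append((kind, names))
            continue
        # Names read by any later non-del statement are still live.
        later_reads = {n for k, ns in stmts[i + 1:] if k != "del" for n in ns}
        survivors = tuple(n for n in names if n not in keep_alive and n not in later_reads)
        if survivors:
            pruned.append(("del", survivors))
    return pruned

# Shape of CUDAGraph0_fn from the traceback: compute t0 from x, del x, del t0, return [t0]
stmts = [("compute", ("x",)), ("del", ("x",)), ("del", ("t0",)), ("return", ("t0",))]
print(prune_dels(stmts, keep_alive={"t0"}))
# [('compute', ('x',)), ('del', ('x',)), ('return', ('t0',))]
```

Here `del x` is still allowed because nothing in the region reads `x` afterwards, while `del t0` is dropped because `t0` leaves the region.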

tfogal commented 2 weeks ago

It's probably just data-dependent ops, but pinging @mruberry, whom I copied that comment from ;-)