Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Make CudaGraph wrapping a transform (but call it explicitly) #635

Closed t-vi closed 2 days ago

t-vi commented 1 week ago

Currently, when running

import torch
import thunder

def fn(x, y):
    a = torch.relu(x * y)
    return a

jfn = thunder.jit(fn)
x, y = torch.randn(2, 5, 5, device="cuda")  # unpacks along dim 0 into two 5x5 CUDA tensors
jfn(x, y)

trace = thunder.last_traces(jfn)[-1]

I get

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, y):
  # x: "cuda:0 f32[5, 5]"
  # y: "cuda:0 f32[5, 5]"
  [a] = TorchCompile0(x, y)
    # result = ltorch.mul(x, y)  # result: "cuda:0 f32[5, 5]"
      # result = prims.mul(x, y)  # result: "cuda:0 f32[5, 5]"
    # a = ltorch.relu(result, False)  # a: "cuda:0 f32[5, 5]"
      # t1 = ltorch.gt(result, 0)  # t1: "cuda:0 b8[5, 5]"
        # _ = prims.convert_element_type(0, float)
        # t1 = prims.gt(result, 0.0)  # t1: "cuda:0 b8[5, 5]"
      # a = ltorch.where(t1, result, 0)  # a: "cuda:0 f32[5, 5]"
        # _ = prims.convert_element_type(0, float)
        # a = prims.where(t1, result, 0.0)  # a: "cuda:0 f32[5, 5]"
  del x, y

  # /tmp/ipykernel_982252/871691319.py:2:       a = torch.relu(x * y)
  return a

Ideally, we could transform the execution trace (extrace) for CUDA graphs instead of wrapping the callable:

# Constructed by CUDA graph transform  (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, y):
  # x: "cuda:0 f32[5, 5]"
  # y: "cuda:0 f32[5, 5]"
  [a] = CudaGraph0(x, y)
    # [a] = TorchCompile0(x, y)
      # result = ltorch.mul(x, y)  # result: "cuda:0 f32[5, 5]"
        # result = prims.mul(x, y)  # result: "cuda:0 f32[5, 5]"
      # a = ltorch.relu(result, False)  # a: "cuda:0 f32[5, 5]"
        # t1 = ltorch.gt(result, 0)  # t1: "cuda:0 b8[5, 5]"
          # _ = prims.convert_element_type(0, float)
          # t1 = prims.gt(result, 0.0)  # t1: "cuda:0 b8[5, 5]"
        # a = ltorch.where(t1, result, 0)  # a: "cuda:0 f32[5, 5]"
          # _ = prims.convert_element_type(0, float)
          # a = prims.where(t1, result, 0.0)  # a: "cuda:0 f32[5, 5]"
  del x, y

  # /tmp/ipykernel_982252/871691319.py:2:       a = torch.relu(x * y)
  return a

This could be applied in the same way to both backward and forward (or a joint trace eventually).
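For context, what a CudaGraph0 region has to do at runtime is PyTorch's standard capture/replay pattern, regardless of whether it is produced by a wrapper or by a transform. A minimal sketch, assuming fixed-shape CUDA tensor inputs (the class name and warmup count are illustrative, not thunder's implementation):

import torch

class CUDAGraphCallable:
    # Illustrative sketch (not thunder's implementation) of what a
    # CudaGraph0 region does at runtime: capture once, then replay.
    # Assumes all inputs are CUDA tensors with fixed shapes.
    def __init__(self, fn, num_warmup_iters=3):
        self.fn = fn
        self.num_warmup_iters = num_warmup_iters
        self.graph = None
        self.static_inputs = None
        self.static_outputs = None

    def _capture(self, args):
        # warm up on a side stream so capture starts from a clean state
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(self.num_warmup_iters):
                self.fn(*args)
        torch.cuda.current_stream().wait_stream(s)

        # record the region once into static input/output buffers
        self.static_inputs = [a.clone() for a in args]
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_outputs = self.fn(*self.static_inputs)

    def __call__(self, *args):
        if self.graph is None:
            self._capture(args)
        # copy the new inputs into the captured buffers, then replay
        for static, new in zip(self.static_inputs, args):
            static.copy_(new)
        self.graph.replay()
        return self.static_outputs

Because replay reuses fixed buffer addresses, such a region naturally belongs at the very end of the trace pipeline, after fusion passes like TorchCompile0 have run.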

My idea would be that this section of thunder/__init__.py

https://github.com/Lightning-AI/lightning-thunder/blob/9f9dcafc9ba5b07652bbab91a602aec3c628c8d1/thunder/__init__.py#L623-L636

would be changed to:

if cd.use_cudagraphs:
    computation_trc = cuda_graph_transform(computation_trc)
    computation_traces.append(computation_trc)
    if backward_trc is not None:
        backward_trc = cuda_graph_transform(backward_trc)
        backward_traces.append(backward_trc)

comp = computation_trc.python_callable()

if backward_trc is not None:
    backward_fn = backward_trc.python_callable()
else:
    backward_fn = None
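From the caller's perspective nothing else would change. Assuming use_cudagraphs remains the opt-in compile option (as cd.use_cudagraphs above suggests), the wrapped trace would then be inspectable like any other:

import torch
import thunder

def fn(x, y):
    return torch.relu(x * y)

# assumption: use_cudagraphs stays the opt-in flag, per cd.use_cudagraphs above
jfn = thunder.jit(fn, use_cudagraphs=True)
x, y = torch.randn(2, 5, 5, device="cuda")
jfn(x, y)

# the CUDA-graph trace is appended to the history, so it is the last entry
print(thunder.last_traces(jfn)[-1])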

@nikitaved

mruberry commented 4 days ago

triage review: this sounds great! If we're really going after CUDA graphs, then a design review would be awesome.