Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

jit: support dataclass as an output for jitted function #632

Closed: kshitij12345 closed this 4 days ago

kshitij12345 commented 1 week ago

Fixes: #623

Example Snippet:

from dataclasses import dataclass
import torch
import thunder

@dataclass
class MyContainer:
    res: torch.Tensor
    num: int

def fn(x):
    return MyContainer(x, 1)

jfn = thunder.jit(fn)

x = torch.ones(3)

print(jfn(x))
print(jfn(x + 1))
print(thunder.last_traces(jfn)[-1])

Output

MyContainer(res=tensor([1., 1., 1.]), num=1)
MyContainer(res=tensor([2., 2., 2.]), num=1)
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x):
  # x: "cpu f32[3]"
  return __main___MyContainer(res=x, num=1)

To get this working, we require:

  1. Support for representing the dataclass in the trace.
  2. Support for flattening the dataclass so that we can see whether it stores any proxies; this is needed when inspecting the flat args, kwargs, and output of a BoundSymbol (see the sketch after this list).
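
For illustration, the flattening requirement could look roughly like the following sketch. This is not thunder's actual tree_flatten_with_dataclass; flatten_dataclass and unflatten are hypothetical helpers that only show how a dataclass instance can be split into its field values (so a caller could check whether any of them is a proxy) and rebuilt afterwards.

# Illustrative sketch only, not thunder's tree_flatten_with_dataclass.
# flatten_dataclass splits a dataclass instance into its leaf values and a
# rebuild function, so a caller can inspect the leaves for proxies.
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Callable

def flatten_dataclass(obj: Any) -> tuple[list[Any], Callable[[list[Any]], Any]]:
    assert is_dataclass(obj) and not isinstance(obj, type)
    names = [f.name for f in fields(obj)]
    leaves = [getattr(obj, name) for name in names]
    cls = type(obj)

    def unflatten(new_leaves: list[Any]) -> Any:
        # Rebuild an instance of the same dataclass from (possibly replaced) leaves.
        return cls(**dict(zip(names, new_leaves)))

    return leaves, unflatten

@dataclass
class Container:
    res: Any
    num: int

leaves, unflatten = flatten_dataclass(Container(res="tensor-or-proxy", num=1))
print(leaves)             # ['tensor-or-proxy', 1]
print(unflatten(leaves))  # Container(res='tensor-or-proxy', num=1)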

Main changes:

  1. codeutils.py - Changes in this file relate to how a dataclass instance looks and is represented in the trace. In this PR, the name of the class is class.__module__ + class.__qualname__ to avoid collisions when two classes with the same name come from different modules (this can be seen in the snippet above; a rough sketch of the naming scheme follows this list).
  2. pytree.py - We provide tree_flatten_with_dataclass, which also flattens dataclasses so that we can see all the proxies that could be contained in a dataclass instance.
  3. symbol.py - We update functions in symbol.py to use tree_flatten_with_dataclass instead of tree_flatten, which keeps dataclass instances opaque.
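
As a rough illustration of the naming scheme in point 1 (a sketch only, not the actual codeutils.py code), the trace-level name can be built from the class's module and qualified name and then sanitized into a valid identifier, which is why the trace above refers to __main___MyContainer:

# Sketch of the collision-avoiding name, not the actual codeutils.py logic:
# combine the defining module with the qualified class name, then make it a
# valid Python identifier for use in the generated trace.
from dataclasses import dataclass

def trace_name_for(cls: type) -> str:
    raw = f"{cls.__module__}_{cls.__qualname__}"
    # Dots (e.g. from nested modules or qualnames) are not valid in identifiers.
    return raw.replace(".", "_")

@dataclass
class MyContainer:
    res: object
    num: int

# Prints "__main___MyContainer" when this file is run directly, matching the
# name shown in the trace output above.
print(trace_name_for(MyContainer))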
t-vi commented 1 week ago

Hi @kshitij12345. Thank you for working on this. There are two aspects:

t-vi commented 4 days ago

@kshitij12345 Just to add: I think if we don't have a better solution this week and this enables running HF/NeMo models, I'd probably be looking at taking it as a stop-gap thing.

kshitij12345 commented 4 days ago

Sounds good, I am looking at this a bit more. Also, if required, I will sync with you offline. Thanks for having a look.