Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

support returning dataclasses with tensors #623

Closed: t-vi closed this 4 days ago

t-vi commented 1 week ago

This needs to work for transformers (and is thus high-impact / high priority):

from dataclasses import dataclass
import torch
import thunder

@dataclass
class MyContainer:
    res: torch.Tensor
    num: int

def fn(x):
    return MyContainer(x, 1)

jfn = thunder.jit(fn)

x = torch.randn(5, 5)

print(jfn(x))  # should match eager fn(x), i.e. MyContainer(res=tensor(...), num=1)

(Somewhat related is #461, but that issue is about avoiding infinite recursion; here we want the feature itself.)
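One common way a tracing compiler can support outputs like the repro above is to flatten the returned container into a list of leaves (the tensors and other values) plus a "spec" describing the structure, trace only the leaves, and rebuild the container afterwards. The sketch below is a minimal, hypothetical illustration of that idea for dataclasses; it is not how thunder implements this, and the names `flatten`, `unflatten` and `_build` are made up for illustration. It assumes every dataclass field is accepted by `__init__` (the default) and treats anything that is not a dataclass instance, such as a `torch.Tensor` or an `int`, as a leaf. The demo uses plain Python values so it runs without torch.

```python
import dataclasses
from typing import Any, Iterator, List, Tuple

# Hypothetical spec format: None marks a leaf; a dataclass is described by
# (cls, [(field_name, child_spec), ...]).
Spec = Any


def flatten(obj: Any) -> Tuple[List[Any], Spec]:
    """Split a (possibly nested) dataclass instance into flat leaves and a spec."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        leaves: List[Any] = []
        children = []
        for f in dataclasses.fields(obj):
            sub_leaves, sub_spec = flatten(getattr(obj, f.name))
            leaves.extend(sub_leaves)
            children.append((f.name, sub_spec))
        return leaves, (type(obj), children)
    # Anything else (e.g. a torch.Tensor or an int) is treated as a single leaf.
    return [obj], None


def _build(it: Iterator[Any], spec: Spec) -> Any:
    """Consume leaves from the iterator in flatten order and rebuild the value."""
    if spec is None:
        return next(it)
    cls, children = spec
    # Assumes all fields are init=True, so they can be passed to the constructor.
    kwargs = {name: _build(it, child) for name, child in children}
    return cls(**kwargs)


def unflatten(leaves: List[Any], spec: Spec) -> Any:
    """Inverse of flatten: rebuild the container from (possibly new) leaves."""
    return _build(iter(leaves), spec)


@dataclasses.dataclass
class MyContainer:
    res: Any  # a torch.Tensor in the issue; a plain list here so the demo runs without torch
    num: int


out = MyContainer([1.0, 2.0], 1)
leaves, spec = flatten(out)
print(leaves)  # [[1.0, 2.0], 1]
print(unflatten(leaves, spec) == out)  # True
```

In a compiler, the traced program would return only `leaves` (as tensor proxies plus constants), and a wrapper would call `unflatten` on the real outputs with the saved spec, so the caller still receives a `MyContainer`.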

cc @apaz-cli