Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Consistency in the order of applying a transform and using `thunder.jit` #185

Open kshitij12345 opened 3 months ago

kshitij12345 commented 3 months ago

The `grad` transform has to be applied after `thunder.jit`, while `autocast` has to be applied before `thunder.jit`:

import torch
import torch.nn as nn
import thunder
from thunder.core.transforms import grad, autocast

def foo(x):
    return x

x = torch.randn(3, device='cpu')

jfoo = thunder.jit(foo)

# RuntimeError: Can only transform compiled thunder functions
# o = thunder.jit(grad(foo))(x)

# Works
o = grad(thunder.jit(foo))(x)

# Works
o = thunder.jit(autocast(foo, dtype=thunder.dtypes.bfloat16))(x)

# NotImplementedError: Attempting to execute outside of a tracing context, which is not supported
# o = autocast(thunder.jit(foo), dtype=thunder.dtypes.bfloat16)(x)

t-vi commented 1 month ago

We should aim for all transforms being applied after `jit`, so I'd say this is a bug in `autocast`.
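The proposed convention (every transform wraps the already-compiled function) can be sketched with a toy model. This is not Thunder's real API: `Compiled`, `toy_jit`, and `make_transform` are hypothetical stand-ins that only illustrate why transform-after-jit composes uniformly, since each transform receives the compiled object rather than opaque Python source.

```python
# Toy sketch (NOT Thunder's actual API) of the "all transforms after jit"
# convention suggested above. A transform accepts only a compiled object
# and returns a new compiled object, so any number of transforms stack
# in a consistent order.

class Compiled:
    """Stand-in for a thunder.jit result; records applied transforms."""
    def __init__(self, fn, transforms=()):
        self.fn = fn
        self.transforms = tuple(transforms)

    def __call__(self, *args):
        return self.fn(*args)

def toy_jit(fn):
    """Stand-in for thunder.jit: wraps a plain function."""
    return Compiled(fn)

def make_transform(name):
    """Build a transform that, like grad today, rejects non-compiled input."""
    def transform(cfn):
        if not isinstance(cfn, Compiled):
            raise RuntimeError("Can only transform compiled thunder functions")
        return Compiled(cfn.fn, cfn.transforms + (name,))
    return transform

toy_grad = make_transform("grad")
toy_autocast = make_transform("autocast")

# Both transforms applied after jit, in either order, always compose:
g = toy_autocast(toy_grad(toy_jit(lambda x: x)))
print(g.transforms)  # transforms recorded in application order: ('grad', 'autocast')

# Applying a transform before jit fails fast, matching the grad behavior
# in the repro above:
try:
    toy_grad(lambda x: x)
except RuntimeError as e:
    print(e)
```

Under this convention, the `autocast`-before-`jit` ordering in the repro would be the error case, and `autocast(thunder.jit(foo), ...)` the supported one.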