Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
Consistency in the order of applying a transform and using `thunder.jit` #185
The `grad` transform has to be applied after `thunder.jit`, while `autocast` has to be applied before `thunder.jit`:
```python
import torch
import torch.nn as nn
import thunder
from thunder.core.transforms import grad, autocast
from thunder.examine import examine

def foo(x):
    return x

x = torch.randn(3, device='cpu')
jfoo = thunder.jit(foo)

# RuntimeError: Can only transform compiled thunder functions
# o = thunder.jit(grad(foo))(x)

# Works
o = grad(thunder.jit(foo))(x)

# Works
o = thunder.jit(autocast(foo, dtype=thunder.dtypes.bfloat16))(x)

# NotImplementedError: Attempting to execute outside of a tracing context, which is not supported
# o = autocast(thunder.jit(foo), dtype=thunder.dtypes.bfloat16)(x)
```
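The asymmetry can be modeled with a small, self-contained sketch (toy stand-ins in plain Python, not Thunder's actual implementation): a post-compile transform like `grad` only accepts an already-compiled callable, while a trace-time transform like `autocast` must see the raw function before compilation. The names `toy_jit`, `toy_grad`, and `toy_autocast` are hypothetical and exist only to illustrate the ordering constraint.

```python
def toy_jit(fn):
    # Stand-in for thunder.jit: marks the function as compiled.
    def compiled(x):
        return fn(x)
    compiled.is_compiled = True
    return compiled

def toy_grad(fn):
    # Post-compile transform: rejects anything not already compiled,
    # mirroring "Can only transform compiled thunder functions".
    if not getattr(fn, "is_compiled", False):
        raise RuntimeError("Can only transform compiled thunder functions")
    def with_grad(x):
        return fn(x)  # placeholder for gradient logic
    return with_grad

def toy_autocast(fn):
    # Trace-time transform: must wrap the raw function, mirroring the
    # "outside of a tracing context" failure when applied after jit.
    if getattr(fn, "is_compiled", False):
        raise NotImplementedError(
            "Attempting to execute outside of a tracing context")
    def casted(x):
        return fn(x)  # placeholder for dtype-casting logic
    return casted

def foo(x):
    return x

o = toy_grad(toy_jit(foo))(1.0)      # works: grad after jit
o = toy_jit(toy_autocast(foo))(1.0)  # works: autocast before jit
```

Under this model the two working orders from the repro compose naturally, and the two failing orders raise at transform-application time, which is where a uniform convention (always-before or always-after `thunder.jit`) would remove the surprise.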