Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Note that the nvFuser executor does not work for this implementation; the trace it generates and the resulting error are shown below:
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b):
# a: "cuda:0 f32[2, 2]"
# b: "cuda:0 f32[2, 2]"
[c, d] = nvFusion0(a, b)
# c = prims.exp(a) # c: "cuda:0 f32[2, 2]"
# d = prims.tanh(b) # d: "cuda:0 f32[2, 2]"
# t2 = prims.add(c, d) # t2: "cuda:0 f32[2, 2]"
# t6 = prims.sub(t5, b) # t6: "cuda:0 f32[2, 2]"
# prims.copy_(t2, c)
# prims.copy_(t6, d)
del a, b
return (c, d)
Traceback (most recent call last):
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snipet.py", line 28, in <module>
main()
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/snipet.py", line 19, in main
c, d = jit_f(a, b)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/__init__.py", line 662, in fn_
result = cache_entry.computation_fn(*inps)
File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "/home/mkozuki/ghq/github.com/crcrpar/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
return func(*args, **kwargs)
File "thunder.computation_1", line 10, in computation
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 402, in __call__
fd = self.get_fd(to_descriptors(args))
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 512, in get_fd
return create_fd(bsyms, input_descriptors, sorted_unique_inputs, sorted_unique_outputs)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 274, in create_fd
translate_bound_symbol(bsym)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 264, in translate_bound_symbol
nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 1842, in sub
nva = getnv(a, fd, lc_to_nv_map)
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 116, in getnv
return lc_to_nv_map[x]
File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/utils.py", line 919, in __getitem__
return self._dict[key_]
KeyError: 't5'
Using only the torch executor, the following trace can be generated:
# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def computation(a, b):
# a: "cuda:0 f32[2, 2]"
# b: "cuda:0 f32[2, 2]"
c = torch.exp(a) # c: "cuda:0 f32[2, 2]"
# c = ltorch.exp(a) # c: "cuda:0 f32[2, 2]"
# c = prims.exp(a) # c: "cuda:0 f32[2, 2]"
d = torch.tanh(b) # d: "cuda:0 f32[2, 2]"
# d = ltorch.tanh(b) # d: "cuda:0 f32[2, 2]"
# d = prims.tanh(b) # d: "cuda:0 f32[2, 2]"
t2 = torch.add(c, d) # t2: "cuda:0 f32[2, 2]"
# t2 = ltorch.add(c, d, alpha=None) # t2: "cuda:0 f32[2, 2]"
# t2 = prims.add(c, d) # t2: "cuda:0 f32[2, 2]"
t4 = torch.div(d, a) # t4: "cuda:0 f32[2, 2]"
# t4 = ltorch.div(d, a, rounding_mode=None, out=None) # t4: "cuda:0 f32[2, 2]"
# t4 = ltorch.true_divide(d, a) # t4: "cuda:0 f32[2, 2]"
# t4 = prims.div(d, a) # t4: "cuda:0 f32[2, 2]"
del a
t6 = torch.sub(t4, b) # t6: "cuda:0 f32[2, 2]"
# t6 = ltorch.sub(t4, b, alpha=None) # t6: "cuda:0 f32[2, 2]"
# t6 = prims.sub(t4, b) # t6: "cuda:0 f32[2, 2]"
del t4, b
copy_(t2, c)
del t2
copy_(t6, d)
del t6
return (c, d)
The snippet used is:
import torch
import thunder
def f(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Reproduction function: two out-of-place ops followed by in-place updates.

    ``c`` and ``d`` are fresh results of ``torch.exp`` / ``torch.tanh`` and are
    then modified in place. The inputs ``a`` and ``b`` are only read.

    Returns:
        The tuple ``(c, d)`` after the in-place updates.
    """
    c = torch.exp(a)
    d = torch.tanh(b)
    # In-place add: updates ``c`` with ``c + d``.
    c += d
    # Two consecutive in-place ops on the same tensor ``d``.
    # NOTE(review): in the failing nvFuser trace the ``sub`` step reads an
    # undefined ``t5`` (KeyError: 't5'), and no ``div`` appears in the fusion
    # -- presumably the result of this ``div_`` is what went missing; confirm.
    d.div_(a)
    d.sub_(b)
    return c, d
def main() -> None:
    """Run ``f`` through ``thunder.jit`` and compare the results with a direct call to ``f``."""
    # Two random 2x2 CUDA tensors that do not require gradients.
    a, b = [torch.randn((2, 2), device="cuda", requires_grad=False) for _ in range(2)]
    # Separate copies for the reference call, so both calls get their own inputs.
    a_, b_ = a.clone().detach(), b.clone().detach()
    # The trailing comment holds an optional ``executors=`` argument; enabling it
    # passes only the torch executor to ``thunder.jit``.
    jit_f = thunder.jit(f) #, executors=[thunder.executors.get_torch_executor()])
    c, d = jit_f(a, b)
    # Reference results from calling ``f`` directly on the copies.
    c_, d_ = f(a_, b_)
    # Print the last entry of the traces recorded for ``jit_f``.
    print(thunder.last_traces(jit_f)[-1])
    # The jitted outputs must match the reference outputs.
    torch.testing.assert_close(c, c_)
    torch.testing.assert_close(d, d_)
# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    main()
note that nvfuser executor does not work for this impl:
torch executor only can generate the following trace:
the used snippet is: