Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
# Repro for Issue 1: inside a loop over a ModuleList, the string print
# prims.python_print("l") is silently dropped by thunder.jit, while the
# tensor print prims.python_print(x.sum()) is kept.
import thunder
import torch
from thunder.core import prims


class MyModel(torch.nn.Module):
    """Three stacked 1->1 linear layers with debug prints between them."""

    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        for l in self.linears:
            # Expected per iteration: print "l", then the current sum,
            # then apply the layer.
            prims.python_print("l")
            prims.python_print(x.sum())
            x = l(x)
        return x


fn = MyModel()
x = torch.randn(10, 1)
# disable_torch_autograd keeps the trace forward-only, which makes the
# missing prints easier to see in the generated program.
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)
Output:
tensor(2.7485)
tensor(6.1350)
tensor(-3.8157)
The string `l` is missing from the output: only the tensor sums are printed, so the `prims.python_print("l")` calls were dropped.
Issue 2 (same problem with the loop unrolled by hand):
# Repro for Issue 2: the hand-unrolled version of the same model — the
# labelled string prints ("l0", "l1", "l2") are also dropped or misplaced
# by thunder.jit, while the tensor prints survive.
import thunder
import torch
from thunder.core import prims


class MyModel(torch.nn.Module):
    """Three stacked 1->1 linear layers, applied explicitly in sequence."""

    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        # Expected: each label prints before the matching sum, and each
        # sum prints before the corresponding layer is applied.
        prims.python_print("l0")
        prims.python_print(x.sum())
        x = self.linears[0](x)
        prims.python_print("l1")
        prims.python_print(x.sum())
        x = self.linears[1](x)
        prims.python_print("l2")
        prims.python_print(x.sum())
        x = self.linears[2](x)
        return x


fn = MyModel()
x = torch.randn(10, 1)
# Forward-only trace, as in Issue 1.
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)
🐛 Bug
To Reproduce
Issue 1:
Output:
The string `l`
is missing. Issue 2:
Output:
Expected behavior
Each `prims.python_print` call should produce output at the point in execution where it appears in the source, exactly once per invocation (so the string labels should be interleaved with the tensor sums).