Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

`python_print` doesn't work as expected #312

Open carmocca opened 1 year ago

carmocca commented 1 year ago

šŸ› Bug

To Reproduce

Issue 1:

import thunder
import torch
from thunder.core import prims

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        for l in self.linears:
            prims.python_print("l")
            prims.python_print(x.sum())
            x = l(x)
        return x

fn = MyModel()
x = torch.randn(10, 1)
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)

Output:

tensor(2.7485)
tensor(6.1350)
tensor(-3.8157)

The string `"l"` is missing from the output: `prims.python_print("l")` should have printed it before each of the three sums.
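
A quick way to check where the string argument gets dropped is to dump the trace Thunder actually executes; a minimal sketch, assuming the jitted `fn` from the snippet above and that `thunder.last_traces` is available in this build:

import thunder

# after calling fn(x) above, fetch the transformed computation traces
traces = thunder.last_traces(fn)
# the last trace is the one that gets executed; check whether the
# python_print call for the string "l" survives into it
print(traces[-1])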

Issue 2:

import thunder
import torch
from thunder.core import prims

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = torch.nn.ModuleList(torch.nn.Linear(1, 1) for _ in range(3))

    def forward(self, x):
        prims.python_print("l0")
        prims.python_print(x.sum())
        x = self.linears[0](x)
        prims.python_print("l1")
        prims.python_print(x.sum())
        x = self.linears[1](x)
        prims.python_print("l2")
        prims.python_print(x.sum())
        x = self.linears[2](x)
        return x

fn = MyModel()
x = torch.randn(10, 1)
fn = thunder.jit(fn, disable_torch_autograd=True)
fn(x)

Output:

l0
l1
l2
tensor(-1.1433)
tensor(8.6309)
tensor(-3.1484)

Expected behavior

The prints should appear at the points where they were called, and as many times as they were called. For Issue 1 that means `"l"` is printed before each of the three sums; for Issue 2 it means each label is printed immediately before its corresponding sum, instead of all labels being hoisted to the top of the output.
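
For illustration, reusing the values from the Issue 2 run above, the expected output would interleave as:

l0
tensor(-1.1433)
l1
tensor(8.6309)
l2
tensor(-3.1484)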

mruberry commented 1 year ago

Awesome issues, thanks @carmocca! I'll get on them ASAP