Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

ThunderFX splitter should send no_grad regions to Thunder #1420

Open IvanYashchuk opened 4 days ago

IvanYashchuk commented 4 days ago

🐛 Bug

The splitter used when Thunder runs as a Dynamo backend should send regions of code under a torch.no_grad context manager to Thunder. Currently, it sends these computations to Inductor instead.

from thunder.dynamo import ThunderCompiler
import thunder
import torch

def f(x):
    with torch.no_grad():
        return x * x

jit_f = thunder.jit(f)
backend = ThunderCompiler()
compile_f = torch.compile(backend=backend)(f)

x = torch.randn(3, 3, requires_grad=True)

out = jit_f(x) # Works with thunder.jit with a warning
out_1 = compile_f(x) # Works with torch.compile, but the computation is sent to Inductor instead of Thunder

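# print_readable() prints the graph itself and returns the string,
# so wrapping it in print() would show the output twice.
backend.subgraph_infos[0].split_graph_module.print_readable()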

prints:

class GraphModule(torch.nn.Module):
    def forward(self, l_x_: "f32[3, 3]"):
        # No stacktrace found for following nodes
        inductor_1 = self.inductor_1(l_x_);  l_x_ = None
        return (inductor_1,)

    class inductor_1(torch.nn.Module):
        def forward(self, l_x_: "f32[3, 3]"):
            # No stacktrace found for following nodes
            _set_grad_enabled = torch._C._set_grad_enabled(False);  _set_grad_enabled = None

             # File: <ipython-input-3-f5a278fc59c9>:7 in f, code: return x * x
            mul: "f32[3, 3]" = l_x_ * l_x_;  l_x_ = None

            # No stacktrace found for following nodes
            _set_grad_enabled_1 = torch._C._set_grad_enabled(True);  _set_grad_enabled_1 = None
            return mul
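
For larger models, reading the whole printout is tedious, so the same check can be reduced to counting regions per executor. The helper below is only a sketch: `count_regions_by_executor` is a hypothetical name, and it assumes the split graph's submodules are named `<executor>_<index>`, as `inductor_1` is above. That Thunder regions would show up with a `thunder_` prefix is an assumption, not something this output confirms.

```python
from collections import Counter


def count_regions_by_executor(submodule_names):
    """Count split-graph regions per executor, keyed by submodule name prefix.

    Assumes names look like "<executor>_<index>" (e.g. "inductor_1").
    Names that do not match this pattern are ignored.
    """
    counts = Counter()
    for name in submodule_names:
        # Split on the last underscore: "inductor_1" -> ("inductor", "_", "1")
        prefix, _, suffix = name.rpartition("_")
        if prefix and suffix.isdigit():
            counts[prefix] += 1
    return dict(counts)


# With the repro above (requires torch and thunder), the names could be
# collected from the split GraphModule like this:
#   split_gm = backend.subgraph_infos[0].split_graph_module
#   names = [name for name, _ in split_gm.named_children()]
#   print(count_regions_by_executor(names))
```

If the no_grad region were sent to Thunder, the count for the Inductor prefix in this repro would be expected to drop to zero.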

HF's Qwen 2 model, added in https://github.com/Lightning-AI/lightning-thunder/pull/1406, creates a small Inductor region because Qwen2RotaryEmbedding.forward is decorated with torch.no_grad.

IvanYashchuk commented 4 days ago

Sending no_grad to Thunder was disabled in https://github.com/Lightning-AI/lightning-thunder/pull/1282 for a good reason (https://github.com/Lightning-AI/lightning-thunder/issues/1219). Maybe it's time to properly support the no_grad context in Thunder?