Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Traces with bsyms of `torch.autograd.Function` are not printable #1404

Open crcrpar opened 2 weeks ago

crcrpar commented 2 weeks ago


🐛 Bug

If a trace includes one or more BoundSymbols of a torch.autograd.Function, it is not printable.

The lookaside is https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L655.

When the symbol is registered, no module is specified: https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L688-L691, so printing the trace arrives at the assertion in https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/symbol.py#L604-L616
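
To make the failure mode concrete, here is a minimal stand-alone sketch (simplified names, not thunder's actual classes) of the module-import branch of BoundSymbol.import_ctx that the traceback below points at; a symbol registered with module=None trips the assertion as soon as the trace is rendered to Python source:

import sys
import types
from typing import Optional

def gather_import_ctx(module: Optional[types.ModuleType]) -> dict[str, types.ModuleType]:
    # Simplified stand-in for the module-import branch of BoundSymbol.import_ctx.
    assert module is not None  # symbols registered by the lookaside carry module=None
    import_ctx = {module.__name__: module}
    # Also include the root package of a dotted module path.
    if "." in module.__name__:
        root = module.__name__.split(".")[0]
        import_ctx[root] = sys.modules[root]
    return import_ctx

print(gather_import_ctx(sys))  # fine: {'sys': <module 'sys' (built-in)>}
gather_import_ctx(None)        # AssertionError, mirroring the traceback below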

To Reproduce

Code sample

import torch

import thunder

class MyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, weight: torch.Tensor, shape: tuple[int, int]) -> torch.Tensor:
        ctx.shape = shape
        ctx.save_for_backward(x, weight)
        ctx.pretty_attr = 100
        ctx.scaler = 1.0
        return torch.matmul(x, weight.t())

    @staticmethod
    def backward(ctx, grad_output):
        (x, weight) = ctx.saved_tensors
        assert weight.shape == ctx.shape  # really bogus, just to use ctx.shape
        scaler2 = ctx.shape[0] / ctx.shape[1]
        return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(2, 2, bias=False)

    def forward(self, x):
        return MyLinear.apply(x, self.l1.weight, self.l1.weight.shape)

if __name__ == "__main__":
    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
    model = Model().to(dtype=torch.float64)
    jitted = thunder.jit(model)

    jitted(x)
    print(thunder.last_traces(jitted)[-1])
    print(thunder.last_traces(jitted)[0])

The output: the execution trace (printed first) renders fine, but printing the initial trace fails with an AssertionError:

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(x, weight):
  # x: "cpu f64[2, 2]"
  # weight: "cpu f64[2, 2]"
  t21 = torch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
    # t21 = ltorch.permute(weight, (1, 0))  # t21: "cpu f64[2, 2]"
      # t21 = prims.transpose(weight, (1, 0))  # t21: "cpu f64[2, 2]"
  t22 = torch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
    # t22 = ltorch.matmul(x, t21)  # t22: "cpu f64[2, 2]"
      # t22 = prims.matmul(x, t21)  # t22: "cpu f64[2, 2]"
  del t21
  t11 = shallow_copy(t22)  # t11: "cpu f64[2, 2]"
  del t22
  return {'output': t11, 'flat_args': [x, weight], 'flat_output': (t11,)}, ((weight, x), ())
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py", line 39, in <module>
    print(thunder.last_traces(jitted)[0])
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 510, in __repr__
    return self.python(print_depth=-1)
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 363, in python
    import_ctx, call_ctx, object_ctx = self._gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/trace.py", line 322, in _gather_ctxs
    bsym_import_ctx, bsym_call_ctx, bsym_object_ctx = bsym.gather_ctxs()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 648, in gather_ctxs
    return self.import_ctx(), self._get_call_ctx(), self.object_ctx()
  File "/home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/thunder/core/symbol.py", line 608, in import_ctx
    assert self.sym.module is not None  # TODO: Is this a valid assumption?
AssertionError

Expected behavior

Both thunder.last_traces(jitted)[-1] and thunder.last_traces(jitted)[0] should print without raising an AssertionError.

crcrpar commented 2 weeks ago
diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py
index d8827eed..1d132bc4 100644
--- a/thunder/core/symbol.py
+++ b/thunder/core/symbol.py
@@ -605,15 +605,17 @@ class BoundSymbol(BoundSymbolInterface):
             # BoundSymbols of Symbols without Python implementations (either because they
             #   have Python implementations or defined call ctxs) are assumed to need
             #   a module import to run properly
-            assert self.sym.module is not None  # TODO: Is this a valid assumption?
-            module_name = self.sym.module.__name__
-            import_ctx = {module_name: self.sym.module}
-
-            # TODO Include the other modules on the path?
-            # Also includes the root module of this (potential) submodule
-            if "." in module_name:
-                root_name = module_name.split(".")[0]
-                import_ctx[root_name] = sys.modules[root_name]
+            if self.sym.module is not None:
+                module_name = self.sym.module.__name__
+                import_ctx = {module_name: self.sym.module}
+
+                # TODO Include the other modules on the path?
+                # Also includes the root module of this (potential) submodule
+                if "." in module_name:
+                    root_name = module_name.split(".")[0]
+                    import_ctx[root_name] = sys.modules[root_name]
+            else:
+                import_ctx = {}

         self._import_ctx.update(import_ctx)
         return self._import_ctx

With this patch applied, the trace prints:

def computation(x, weight):
  # x: "cpu f64[2, 2]"
  # weight: "cpu f64[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py:13:            return torch.matmul(x, weight.t())
  t11 = MyLinear_125921219058160_0(x, weight, (2, 2))  # t11: "cpu f64[2, 2]"
    # t9 = ltorch.t(weight)  # t9: "cpu f64[2, 2]"
      # t9 = prims.transpose(weight, (1, 0))  # t9: "cpu f64[2, 2]"
    # t10 = ltorch.matmul(x, t9)  # t10: "cpu f64[2, 2]"
      # t10 = prims.matmul(x, t9)  # t10: "cpu f64[2, 2]"
    # t11 = prims.shallow_copy(t10)  # t11: "cpu f64[2, 2]"

  # /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py:20:            return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None
  return t11
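
For completeness, a quick numerical check that the custom autograd.Function still behaves the same under thunder.jit once the patch above is applied (a sketch reusing MyLinear and Model from the repro script; thunder's autograd integration is assumed to be active):

x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
model = Model().to(dtype=torch.float64)
jitted = thunder.jit(model)

# Forward results should match eager execution of the same Function.
eager_out = model(x)
jit_out = jitted(x)
torch.testing.assert_close(jit_out, eager_out)

# Gradients w.r.t. the input should match as well.
grad_out = torch.ones_like(eager_out)
(eager_grad,) = torch.autograd.grad(eager_out, x, grad_out)
(jit_grad,) = torch.autograd.grad(jit_out, x, grad_out)
torch.testing.assert_close(jit_grad, eager_grad)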