Open · crcrpar opened this issue 2 weeks ago
diff --git a/thunder/core/symbol.py b/thunder/core/symbol.py
index d8827eed..1d132bc4 100644
--- a/thunder/core/symbol.py
+++ b/thunder/core/symbol.py
@@ -605,15 +605,17 @@ class BoundSymbol(BoundSymbolInterface):
# BoundSymbols of Symbols without Python implementations (either because they
# have Python implementations or defined call ctxs) are assumed to need
# a module import to run properly
- assert self.sym.module is not None # TODO: Is this a valid assumption?
- module_name = self.sym.module.__name__
- import_ctx = {module_name: self.sym.module}
-
- # TODO Include the other modules on the path?
- # Also includes the root module of this (potential) submodule
- if "." in module_name:
- root_name = module_name.split(".")[0]
- import_ctx[root_name] = sys.modules[root_name]
+ if self.sym.module is not None:
+ module_name = self.sym.module.__name__
+ import_ctx = {module_name: self.sym.module}
+
+ # TODO Include the other modules on the path?
+ # Also includes the root module of this (potential) submodule
+ if "." in module_name:
+ root_name = module_name.split(".")[0]
+ import_ctx[root_name] = sys.modules[root_name]
+ else:
+ import_ctx = {}
self._import_ctx.update(import_ctx)
return self._import_ctx
Applying the patch above lets me print the trace:
def computation(x, weight):
# x: "cpu f64[2, 2]"
# weight: "cpu f64[2, 2]"
# /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py:13: return torch.matmul(x, weight.t())
t11 = MyLinear_125921219058160_0(x, weight, (2, 2)) # t11: "cpu f64[2, 2]"
# t9 = ltorch.t(weight) # t9: "cpu f64[2, 2]"
# t9 = prims.transpose(weight, (1, 0)) # t9: "cpu f64[2, 2]"
# t10 = ltorch.matmul(x, t9) # t10: "cpu f64[2, 2]"
# t10 = prims.matmul(x, t9) # t10: "cpu f64[2, 2]"
# t11 = prims.shallow_copy(t10) # t11: "cpu f64[2, 2]"
# /home/mkozuki/ghq/github.com/Lightning-AI/lightning-thunder/repro.py:20: return torch.matmul(grad_output, weight) * ctx.scaler, torch.matmul(grad_output.t(), x) / scaler2, None
return t11
Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.
🐛 Bug
If a trace includes one or more `BoundSymbol`s of a `torch.autograd.Function`, the trace is not printable. The lookaside is https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L655.
When the lookaside registers the symbol, it does not specify a module (https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/jit_ext.py#L688-L691). As a result, building the `BoundSymbol`'s import context at https://github.com/Lightning-AI/lightning-thunder/blob/3390c922cfbe1b70c42118dfa8aa71adb3bec692/thunder/core/symbol.py#L604-L616 hits `assert self.sym.module is not None` and fails.
To Reproduce
Code sample
The output:
Expected behavior