Calling `nn.ModuleDict.keys()` inside a `torch.compile`-d module raises an `AssertionError`. Here's a small repro:
import torch
from torch import nn

# Print the build under test so the report records exactly which version failed.
print(torch.__version__)
class Net(nn.Module):
    """Minimal module whose ``forward`` iterates an ``nn.ModuleDict`` via ``.keys()``.

    Under ``torch.compile`` on the reported build (1.14.0a0+gitb8b7480),
    tracing the ``.keys()`` loop below fails with
    ``AssertionError: torch.* op returned non-Tensor odict_keys call_method keys``.
    """

    def __init__(self):
        super().__init__()
        # A single pass-through submodule is enough to reach the failing call.
        self.submodules = nn.ModuleDict({"key": nn.Identity()})

    def forward(self, x):
        """Apply every submodule in ``self.submodules`` to ``x``.

        Args:
            x: input tensor passed unchanged to each submodule.

        Returns:
            dict mapping each ModuleDict key to that submodule's output.
        """
        rv = {}
        # Deliberately iterate via ``.keys()``: this is the call Dynamo fails
        # to trace in the report (iterating ``.items()`` is reported to work).
        for k in self.submodules.keys():
            rv[k] = self.submodules[k](x)
        return rv
# Keep the model and its input on the same device. Fall back to CPU so the
# repro also runs on machines without a GPU; the failure is raised while
# Dynamo traces ``forward``, before any device-specific code runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
net = Net().to(device)
net = torch.compile(net)
# ``torch.tensor`` with float literals gives the same float32 values as the
# legacy ``torch.Tensor([1, 2, 3, 4])`` constructor.
out = net(torch.tensor([1.0, 2.0, 3.0, 4.0], device=device))
print(out)
Error logs
Output of the above:
1.14.0a0+gitb8b7480
Traceback (most recent call last):
File "dynamo_test.py", line 20, in <module>
out = net(torch.Tensor([1, 2, 3, 4]))
File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
return forward_call(*args, **kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
return fn(*args, **kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 479, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
return fn(*args, **kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
r = func(*args, **kwargs)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
return _compile(
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 398, in _compile
out_code = transform_code_object(code, transform)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
transformations(instructions, code_options)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 385, in transform
tracer.run()
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1676, in run
super().run()
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 528, in run
and self.step()
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 496, in step
getattr(self, inst.opname)(inst)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 305, in wrapper
return inner_fn(self, inst)
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 956, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 430, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 241, in call_function
return self.obj.call_method(
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 464, in call_method
return wrap_fx_proxy(
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 709, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 892, in wrap_fx_proxy_cls
raise AssertionError(
AssertionError: torch.* op returned non-Tensor odict_keys call_method keys
from user code:
File "dynamo_test.py", line 13, in forward
for k in self.submodules.keys():
Set torch._dynamo.config.verbose=True for more information
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
(It does work if I change `for k in self.submodules.keys():` to `for k, _ in self.submodules.items():`, but I would prefer not to do that :) )
🐛 Describe the bug
An AssertionError gets triggered when calling `nn.ModuleDict.keys()`. Here's a small repro:

Error logs
Output of the above:
(It does work if I change `for k in self.submodules.keys():` to `for k, _ in self.submodules.items():`, but I would prefer not to do that :) )

Minified repro
No response