pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

AssertionError: torch.* op returned non-Tensor odict_keys call_method keys #1973

Closed: ezhang887 closed this issue 1 year ago

ezhang887 commented 1 year ago

๐Ÿ› Describe the bug

An AssertionError gets triggered when calling nn.ModuleDict.keys(). Here's a small repro:

import torch
import torch.nn as nn

print(torch.__version__)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.submodules = nn.ModuleDict({"key": nn.Identity()})

    def forward(self, x):
        rv = {}
        for k in self.submodules.keys():
            rv[k] = self.submodules[k](x)
        return rv

net = Net()
net = net.cuda()
net = torch.compile(net)
out = net(torch.Tensor([1, 2, 3, 4]))
print(out)

Error logs

Output of the above:

1.14.0a0+gitb8b7480

Traceback (most recent call last):
  File "dynamo_test.py", line 20, in <module>
    out = net(torch.Tensor([1, 2, 3, 4]))
  File "/home/user/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1480, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 479, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 90, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 339, in _convert_frame_assert
    return _compile(
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 398, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 385, in transform
    tracer.run()
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1676, in run
    super().run()
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 528, in run
    and self.step()
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 496, in step
    getattr(self, inst.opname)(inst)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 305, in wrapper
    return inner_fn(self, inst)
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 956, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 430, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 241, in call_function
    return self.obj.call_method(
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 464, in call_method
    return wrap_fx_proxy(
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 709, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/home/user/.local/lib/python3.8/site-packages/torch/_dynamo/variables/builder.py", line 892, in wrap_fx_proxy_cls
    raise AssertionError(
AssertionError: torch.* op returned non-Tensor odict_keys call_method keys

from user code:
   File "dynamo_test.py", line 13, in forward
    for k in self.submodules.keys():

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
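The assertion message names the Python type that .keys() returned: odict_keys. Below is a minimal, torch-free sketch of where that type comes from. It assumes, as the odict_keys in the message suggests, that nn.ModuleDict keeps its submodules in a collections.OrderedDict. The dict contents are placeholders. The result is a plain key view, not a Tensor. That is consistent with the traceback, where the keys call goes through wrap_fx_proxy and fails the Tensor check.

```python
from collections import OrderedDict

# Hypothetical stand-in for nn.ModuleDict's internal storage; the value is a
# placeholder string, not a real module.
modules = OrderedDict({"key": "Identity()"})

# .keys() on an OrderedDict returns an "odict_keys" view object.
keys_view = modules.keys()

print(type(keys_view).__name__)  # odict_keys
print(list(keys_view))           # ['key']
```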

(It does work if I change "for k in self.submodules.keys():" to "for k, _ in self.submodules.items():", but I would prefer not to do that 😅.)
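In eager Python the two loop forms are interchangeable. The workaround only changes which dict method gets called inside forward. The sketch below checks that both loops build the same output. It is a torch-free illustration. Plain callables in an OrderedDict stand in for nn.ModuleDict({"key": nn.Identity()}). The extra "double" entry and the helper names are hypothetical.

```python
from collections import OrderedDict

# Hypothetical stand-ins for the submodules: plain callables, so this runs
# without PyTorch. "key" mimics nn.Identity; "double" is an extra example.
submodules = OrderedDict({"key": lambda x: x, "double": lambda x: 2 * x})

def forward_keys(x):
    # Pattern from the report: iterate over keys, then index the dict.
    return {k: submodules[k](x) for k in submodules.keys()}

def forward_items(x):
    # Workaround from the report: iterate over (key, module) pairs.
    return {k: m(x) for k, m in submodules.items()}

print(forward_keys(3))   # {'key': 3, 'double': 6}
print(forward_items(3))  # {'key': 3, 'double': 6}
```

Both loops visit the keys in insertion order, so the resulting dicts match key for key.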


yanboliang commented 1 year ago

We should support .keys(); I will send a PR to fix it.

yanboliang commented 1 year ago

This has been fixed by https://github.com/pytorch/pytorch/pull/90502