pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[TorchScript] Failure if you script a wrapper module and then an interface-implementing submodule. #140468

Open davidberard98 opened 1 week ago

davidberard98 commented 1 week ago

🐛 Describe the bug

Because the first scripting of the wrapper module sees the submodule as an interface type, it ignores the submodule's methods that are not part of the interface (here, the exported add_two). The resulting type is then cached. When we later script the submodule on its own, TorchScript sees the other methods and fails, because the cached jit_type associated with this class doesn't have them.

Repro:

import torch

@torch.jit.interface
class MyInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

class MyImplementation(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    @torch.jit.export
    def add_two(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2

class MyWrapper(torch.nn.Module):
    impl: MyInterface

    def __init__(self):
        super().__init__()
        self.impl = MyImplementation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.impl(x)

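mod = MyWrapper()
# Scripting the wrapper compiles MyImplementation only through MyInterface (forward only) and caches that type.
mod_s = torch.jit.script(mod)
# Scripting the submodule alone reuses the cached type, which lacks add_two, so this raises RuntimeError.
mod.impl = torch.jit.script(mod.impl)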

Error:

  File "/data/users/dberard/scripts/interface_extra.py", line 31, in <module>
    mod.impl = torch.jit.script(mod.impl)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1429, in script
    ret = _script_impl(
          ^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_script.py", line 1147, in _script_impl
    return torch.jit._recursive.create_script_module(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 557, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/users/dberard/pytorch/torch/jit/_recursive.py", line 679, in create_script_module_impl
    script_method = cpp_module._get_method(name)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Method 'add_two' is not defined.
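To make the ordering problem concrete, here is a deliberately simplified, pure-Python model of a per-class type cache. ToyTypeStore, Impl, and run are hypothetical names invented for illustration, not torch.jit APIs; the model only captures the behavior described above: the first request for a class fixes which methods its cached type contains, and a later request for additional methods fails.

```python
from typing import Dict, FrozenSet, List, Optional, Set


class ToyTypeStore:
    """Hypothetical stand-in for a per-class type cache (not a torch.jit API)."""

    def __init__(self) -> None:
        # Maps a Python class to the method names compiled into its cached type.
        self.compiled: Dict[type, FrozenSet[str]] = {}

    def script(self, cls: type, wanted: Set[str]) -> FrozenSet[str]:
        if cls not in self.compiled:
            # First use of this class: compile only the requested methods and cache them.
            self.compiled[cls] = frozenset(wanted)
        missing = wanted - self.compiled[cls]
        if missing:
            # The cached type is reused as-is, so later requests for extra methods fail.
            raise RuntimeError(f"Method '{sorted(missing)[0]}' is not defined.")
        return self.compiled[cls]


class Impl:
    """Stands in for MyImplementation."""


INTERFACE_METHODS = {"forward"}       # what the wrapper sees through MyInterface
ALL_METHODS = {"forward", "add_two"}  # what scripting the submodule alone needs


def run(order: List[Set[str]]) -> Optional[str]:
    """Script Impl once per entry in `order`; return the error message, if any."""
    store = ToyTypeStore()
    try:
        for wanted in order:
            store.script(Impl, wanted)
    except RuntimeError as err:
        return str(err)
    return None


# Repro order: wrapper first (interface view), then the submodule on its own.
repro_result = run([INTERFACE_METHODS, ALL_METHODS])
# Workaround order: submodule first, then wrapper, then submodule again.
workaround_result = run([ALL_METHODS, INTERFACE_METHODS, ALL_METHODS])

print(repro_result)
print(workaround_result)
```

Running it prints `Method 'add_two' is not defined.` for the repro order and `None` for the workaround order: only the order in which the cache is first populated differs.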

Versions

main branch, CPU build

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

davidberard98 commented 1 week ago

A workaround: say your use case does this (which is also what the repro demonstrates):

1) script an outer module, which contains the inner module annotated as an interface;
2) script the inner module, which fails because the outer module first saw it through the interface annotation.

You can work around the issue by first scripting the inner module on its own, which populates the type cache with the non-interface version of the module type.
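A minimal sketch of that workaround, reusing the repro's classes unchanged. The only addition is a `torch.jit.script(mod.impl)` call before the wrapper is scripted; it assumes the type-cache behavior described in this issue, so results may differ on other PyTorch versions.

```python
import torch


@torch.jit.interface
class MyInterface(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass


class MyImplementation(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * x

    @torch.jit.export
    def add_two(self, x: torch.Tensor) -> torch.Tensor:
        return x + 2


class MyWrapper(torch.nn.Module):
    impl: MyInterface

    def __init__(self):
        super().__init__()
        self.impl = MyImplementation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.impl(x)


mod = MyWrapper()

# Workaround: script the inner module on its own FIRST. This compiles every
# exported method (forward and add_two) into the cached type for MyImplementation.
torch.jit.script(mod.impl)

# Scripting the wrapper now reuses that cached type; the submodule is still
# accessed through MyInterface, so only forward is needed here.
mod_s = torch.jit.script(mod)

# Scripting the submodule again no longer fails, because the cached type has add_two.
mod.impl = torch.jit.script(mod.impl)

x = torch.tensor([3.0])
print(mod_s(x))             # forward through the wrapper: x * x
print(mod.impl.add_two(x))  # exported method on the scripted submodule: x + 2
```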