tensorly / torch

TensorLy-Torch: Deep Tensor Learning with TensorLy and PyTorch
http://tensorly.org/torch/
BSD 3-Clause "New" or "Revised" License
70 stars 18 forks source link

`torch.jit.script` does not work with Tensorized Models #33

Open hello-fri-end opened 7 months ago

hello-fri-end commented 7 months ago

Minimal Code:

import torch
from torch.nn import Module
from tltorch import FactorizedConv

class Test(Module):
    """Minimal module wrapping a single Tucker-factorized convolution.

    Used only to reproduce the ``torch.jit.script`` failure; it defines no
    ``forward`` of its own.
    """

    def __init__(self):
        super().__init__()
        conv = FactorizedConv(3, 4, 3, factorization='tucker', order=3)
        self.layer = conv

def main():
    """Instantiate the test model and attempt to compile it with TorchScript."""
    # This is the call that raises NotSupportedError in the report.
    torch.jit.script(Test())

# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    main()

Error:

Traceback (most recent call last):
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 27, in <module>
    main()
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 24, in main
    save_model(model)
  File "/workspaces/RepNet-Rex-Solutions/test.py", line 8, in save_model
    scripted_module = torch.jit.script(model)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 1324, in script
    return torch.jit._recursive.create_script_module(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
    init_fn(script_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
    scripted = create_script_module_impl(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 572, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 899, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 87, in make_stub_from_method
    return make_stub(func, method_name)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 71, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 372, in get_jit_def
    return build_def(
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 422, in build_def
    param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 448, in build_param_list
    raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/usr/local/python/3.10.8/lib/python3.10/site-packages/tltorch/factorized_tensors/core.py", line 259
    def forward(self, indices=None, **kwargs):
                                     ~~~~~~~ <--- HERE
        """To use a tensor factorization within a network, use ``tensor.forward``, or, equivalently, ``tensor()`

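The main issue is that `torch.jit.script` does not support functions that take a variable number of arguments or that use keyword-only arguments with defaults. The `forward` method of the factorized tensors is declared as `forward(self, indices=None, **kwargs)`, so the factorized/tensorized layers cannot be scripted.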

JeanKossaifi commented 4 weeks ago

So does removing the `**kwargs` fix the issue? If so, would you be able to open a small PR?