import torch
from torch.nn import Module
from tltorch import FactorizedConv
class Test(Module):
    """Minimal module wrapping a single Tucker-factorized convolution.

    Used to reproduce ``torch.jit.script`` failing on tltorch factorized
    layers: scripting this module recursively compiles the submodule
    ``forward`` methods, and tltorch's use ``**kwargs``.
    """

    def __init__(self):
        super().__init__()
        # Positional args (3, 4, 3) are presumably in_channels, out_channels,
        # kernel_size, and order=3 presumably selects a 3-D convolution —
        # NOTE(review): confirm against the tltorch FactorizedConv docs.
        self.layer = FactorizedConv(3, 4, 3, factorization='tucker', order=3)

    def forward(self, x):
        """Apply the factorized convolution to ``x`` and return the result.

        Without this method, calling the module (``model(x)``) raises
        ``NotImplementedError`` from ``torch.nn.Module``.
        """
        return self.layer(x)
def main():
    """Construct a ``Test`` model and compile it with TorchScript.

    ``torch.jit.script`` recursively scripts every submodule, which is
    where the tltorch layers' ``**kwargs`` signatures are rejected.
    """
    scripted_module = torch.jit.script(Test())


if __name__ == "__main__":
    # Only run the reproduction when executed directly, not when imported.
    main()
Error:
Traceback (most recent call last):
File "/workspaces/RepNet-Rex-Solutions/test.py", line 27, in <module>
main()
File "/workspaces/RepNet-Rex-Solutions/test.py", line 24, in main
save_model(model)
File "/workspaces/RepNet-Rex-Solutions/test.py", line 8, in save_model
scripted_module = torch.jit.script(model)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 1324, in script
return torch.jit._recursive.create_script_module(
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 559, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 632, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_script.py", line 639, in _construct
init_fn(script_module)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 608, in init_fn
scripted = create_script_module_impl(
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 572, in create_script_module_impl
method_stubs = stubs_fn(nn_module)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 899, in infer_methods_to_compile
stubs.append(make_stub_from_method(nn_module, method))
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 87, in make_stub_from_method
return make_stub(func, method_name)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/_recursive.py", line 71, in make_stub
ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 372, in get_jit_def
return build_def(
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 422, in build_def
param_list = build_param_list(ctx, py_def.args, self_name, pdt_arg_types)
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/torch/jit/frontend.py", line 448, in build_param_list
raise NotSupportedError(ctx_range, _vararg_kwarg_err)
torch.jit.frontend.NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
File "/usr/local/python/3.10.8/lib/python3.10/site-packages/tltorch/factorized_tensors/core.py", line 259
def forward(self, indices=None, **kwargs):
~~~~~~~ <--- HERE
"""To use a tensor factorization within a network, use ``tensor.forward``, or, equivalently, ``tensor()`
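The main issue here is that `torch.jit.script` doesn't support a variable number of arguments (`*args`/`**kwargs`) or keyword-only arguments with defaults. The `forward` methods of the factorized/tensorized layers use these constructs, for example `def forward(self, indices=None, **kwargs)` in `tltorch/factorized_tensors/core.py`.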
Minimal Code:
Error:
The main issue here is that `torch.jit.script` doesn't support a variable number of arguments or keyword-only arguments with defaults, both of which appear in the `forward` functions of the factorized/tensorized layers.