In the clip.py file of piq\feature_extractors, I am trying to save the model using torch jit script in load() function, as I need clip_iqa optimized model for android mobile.
However I am getting the below error. Please suggest how can I save the optimized clip_iqa model.
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 92, in load
traced_script_module = torch.jit.script(model)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_script.py", line 1338, in script
return torch.jit._recursive.create_script_module(
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 558, in create_script_module
return create_script_module_impl(nn_module, concrete_type, stubs_fn)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 631, in create_script_module_impl
script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_script.py", line 647, in _construct
init_fn(script_module)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 607, in init_fn
scripted = create_script_module_impl(
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 571, in create_script_module_impl
method_stubs = stubs_fn(nn_module)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 898, in infer_methods_to_compile
stubs.append(make_stub_from_method(nn_module, method))
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 87, in make_stub_from_method
return make_stub(func, method_name)
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 71, in make_stub
ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 372, in get_jit_def
return build_def(
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 433, in build_def
return Def(Ident(r, def_name), decl, build_stmts(ctx, body))
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in build_stmts
stmts = [build_stmt(ctx, s) for s in stmts]
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in
stmts = [build_stmt(ctx, s) for s in stmts]
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 405, in call
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 259
def forward(self, x, return_token=False, pos_embedding=True):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
In the clip.py file of piq\feature_extractors, I am trying to save the model using torch jit script in load() function, as I need clip_iqa optimized model for android mobile.
traced_script_module = torch.jit.script(model) traced_script_module.save(os.path.join(os.path.expanduser("~/.cache/clip"), "clip_iqa_trace.pt")) traced_script_module_optimized = optimize_for_mobile(traced_script_module) traced_script_module_optimized._save_for_lite_interpreter(os.path.join(os.path.expanduser("~/.cache/clip"), "clip_iqa_mobile.ptl"))
However I am getting the below error. Please suggest how can I save the optimized clip_iqa model.
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 92, in load traced_script_module = torch.jit.script(model) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_script.py", line 1338, in script return torch.jit._recursive.create_script_module( File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 558, in create_script_module return create_script_module_impl(nn_module, concrete_type, stubs_fn) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 631, in create_script_module_impl script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_script.py", line 647, in _construct init_fn(script_module) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 607, in init_fn scripted = create_script_module_impl( File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 571, in create_script_module_impl method_stubs = stubs_fn(nn_module) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 898, in infer_methods_to_compile stubs.append(make_stub_from_method(nn_module, method)) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 87, in make_stub_from_method return make_stub(func, method_name) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit_recursive.py", line 71, in make_stub ast = get_jit_def(func, name, self_name="RecursiveScriptModule") File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 372, in get_jit_def return build_def( File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 433, in build_def return Def(Ident(r, def_name), decl, build_stmts(ctx, body)) File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in build_stmts stmts = [build_stmt(ctx, s) for s in stmts] File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in
stmts = [build_stmt(ctx, s) for s in stmts]
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 405, in call
raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 259
def forward(self, x, return_token=False, pos_embedding=True):
def stem(x):