photosynthesis-team / piq

Measures and metrics for image2image tasks. PyTorch.
Apache License 2.0

CLIP_IQA: Need optimized model for mobile #375

Status: Open. Opened by GaneshPulivendula2024 9 months ago.

GaneshPulivendula2024 commented 9 months ago

In `clip.py` under `piq\feature_extractors`, I am trying to save the model with `torch.jit.script` inside the `load()` function, because I need an optimized `clip_iqa` model for Android.

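```python
import os
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# In load(), after `model` has been constructed
cache_dir = os.path.expanduser("~/.cache/clip")
traced_script_module = torch.jit.script(model)  # this call raises the error below
traced_script_module.save(os.path.join(cache_dir, "clip_iqa_trace.pt"))
traced_script_module_optimized = optimize_for_mobile(traced_script_module)
traced_script_module_optimized._save_for_lite_interpreter(os.path.join(cache_dir, "clip_iqa_mobile.ptl"))
```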

However, I get the error below. How can I save an optimized `clip_iqa` model?

```
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 92, in load
    traced_script_module = torch.jit.script(model)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_script.py", line 1338, in script
    return torch.jit._recursive.create_script_module(
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 558, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 631, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_script.py", line 647, in _construct
    init_fn(script_module)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 607, in init_fn
    scripted = create_script_module_impl(
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 571, in create_script_module_impl
    method_stubs = stubs_fn(nn_module)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 898, in infer_methods_to_compile
    stubs.append(make_stub_from_method(nn_module, method))
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 87, in make_stub_from_method
    return make_stub(func, method_name)
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\_recursive.py", line 71, in make_stub
    ast = get_jit_def(func, name, self_name="RecursiveScriptModule")
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 372, in get_jit_def
    return build_def(
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 433, in build_def
    return Def(Ident(r, def_name), decl, build_stmts(ctx, body))
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in build_stmts
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 195, in <listcomp>
    stmts = [build_stmt(ctx, s) for s in stmts]
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\torch\jit\frontend.py", line 405, in __call__
    raise UnsupportedNodeError(ctx, node)
torch.jit.frontend.UnsupportedNodeError: function definitions aren't supported:
  File "C:\Users\mvx874\AppData\Local\anaconda3\envs\clip_iqa2\lib\site-packages\piq\feature_extractors\clip.py", line 259
    def forward(self, x, return_token=False, pos_embedding=True):
        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
```
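The last frame points at the cause: `torch.jit.script` compiles `forward` from its Python source, and the TorchScript frontend rejects the nested `def stem(x):`. The sketch below reproduces this on hypothetical toy modules (`_StemLayers`, `NestedStem`, `MethodStem`, and the helpers `try_script` and `save_for_mobile` are illustrative names, not piq's code) and shows two possible workarounds. The first moves the helper into a regular method; the loop over submodule pairs is also unrolled, on the assumption that TorchScript may not accept a Python list of submodules. The second uses `torch.jit.trace`, which runs `forward` once on an example input and records the tensor operations instead of parsing the source. Note that tracing fixes any Python-level branching to the path taken for that input, so flags such as `return_token` and `pos_embedding` would be baked in at their traced values. Whether piq's full CLIP-IQA model then scripts or traces cleanly is not established here.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile


class _StemLayers(nn.Module):
    """Hypothetical toy layers shaped like a CLIP ResNet stem (not piq's code)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1, self.bn1 = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
        self.conv2, self.bn2 = nn.Conv2d(8, 8, 3, padding=1), nn.BatchNorm2d(8)
        self.relu = nn.ReLU()


class NestedStem(_StemLayers):
    """Reproduces the failing pattern: a helper function defined inside forward()."""

    def forward(self, x):
        def stem(x):  # nested def: rejected by torch.jit.script
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2)]:
                x = self.relu(bn(conv(x)))
            return x

        return stem(x)


class MethodStem(_StemLayers):
    """Same computation, with the helper as a method and the loop unrolled."""

    def stem(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.stem(x)


def try_script(model: nn.Module) -> bool:
    """Return True if torch.jit.script can compile `model`, False otherwise."""
    try:
        torch.jit.script(model)
        return True
    except Exception:  # e.g. torch.jit.frontend.UnsupportedNodeError
        return False


def save_for_mobile(compiled: torch.jit.ScriptModule, path: str) -> None:
    """The optimize + lite-interpreter steps from the issue, on an already compiled module."""
    optimize_for_mobile(compiled)._save_for_lite_interpreter(path)


if __name__ == "__main__":
    example = torch.rand(1, 3, 32, 32)
    print("NestedStem scriptable:", try_script(NestedStem().eval()))
    print("MethodStem scriptable:", try_script(MethodStem().eval()))

    # Tracing executes forward() and records tensor ops, so the nested def is not parsed.
    traced = torch.jit.trace(NestedStem().eval(), example)
    with tempfile.TemporaryDirectory() as tmp:
        out_path = os.path.join(tmp, "stem_mobile.ptl")
        save_for_mobile(traced, out_path)
        print("saved:", os.path.exists(out_path))
```

Calling `.eval()` before scripting or tracing keeps BatchNorm in inference mode, which is what a mobile export normally needs. Of the two routes, refactoring keeps data-dependent control flow intact but requires editing `clip.py`; tracing needs no source changes but only captures one execution path.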