IDEA-Research / DINO

[ICLR 2023] Official implementation of the paper "DINO: DETR with Improved DeNoising Anchor Boxes for End-to-End Object Detection"
Apache License 2.0

Support for TorchScript export #96

Open cywinski opened 2 years ago

cywinski commented 2 years ago

Hi, I have been trying to export the 4-scale model with ResNet50 backbone to TorchScript with the following code:

import torch
from util.slconfig import SLConfig
from main import build_model_main

args = SLConfig.fromfile('config_cfg.py') 
args.device = 'cuda' 
model, criterion, postprocessors = build_model_main(args)

checkpoint = torch.load('checkpoint_best_regular.pth', map_location='cuda')
model.load_state_dict(checkpoint['model'])
_ = model.eval()

traced_script_module = torch.jit.script(model.to("cuda"))

traced_script_module.save('torchscript_model.pt')

Unfortunately, I get the following error:

Traceback (most recent call last):
  File "convert_to_torchscript.py", line 64, in <module>
    traced_script_module = torch.jit.script(model.to("cuda"))
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 1257, in script
    return torch.jit._recursive.create_script_module(
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 451, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 513, in create_script_module_impl
    script_module = torch.jit.RecursiveScriptModule._construct(cpp_module, init_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 587, in _construct
    init_fn(script_module)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 491, in init_fn
    scripted = create_script_module_impl(orig_value, sub_concrete_type, stubs_fn)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 517, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/opt/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 368, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 
Python builtin <built-in method apply of FunctionMeta object at 0x6280410> is currently not supported in Torchscript:
  File "/home/ir/pvc/dino/models/dino/ops/modules/ms_deform_attn.py", line 112
            raise ValueError(
                'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
        output = MSDeformAttnFunction.apply(
                 ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
            value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)

When I used torch.jit.trace() instead of torch.jit.script(), I got a very similar error. As far as I know, TorchScript does not currently support custom autograd Functions (here, MSDeformAttnFunction). Will TorchScript export be supported, or is there a known workaround for exporting the model?
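For anyone who wants to check this without the DINO code base, here is a minimal sketch of the same pattern: a module whose forward calls `.apply` on a custom `torch.autograd.Function`. The names `Double`, `UsesCustomFunction` and `try_script` are made up for this illustration and are not part of the repository. On the PyTorch version in the traceback above, scripting such a module is expected to fail with the "Python builtin ... is currently not supported in Torchscript" message. Other PyTorch versions may behave differently, so the helper only reports what happened.

```python
import torch
from torch import nn


class Double(torch.autograd.Function):
    """Toy custom autograd Function standing in for MSDeformAttnFunction (hypothetical)."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class UsesCustomFunction(nn.Module):
    """Module whose forward goes through Function.apply, like MSDeformAttn does."""

    def forward(self, x):
        return Double.apply(x)


def try_script(module):
    """Return the error text if torch.jit.script fails, or None if it succeeds."""
    try:
        torch.jit.script(module)
    except Exception as exc:  # the compiler error in the thread is a RuntimeError
        return str(exc)
    return None


error_message = try_script(UsesCustomFunction())
print(error_message if error_message is not None else "scripting succeeded")
```

If the printed message names the `apply` builtin, the failure comes from the custom Function itself and not from anything specific to the DINO model.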

ZZYuting commented 1 year ago

In ops/modules/ms_deform_attn.py, replace

output = MSDeformAttnFunction.apply(
    value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)

with

output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)

It seems to work for me.
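To show what such a replacement computes, here is a rough sketch of multi-scale deformable attention written with built-in PyTorch ops only (`torch.nn.functional.grid_sample`). It is not copied from the repository. The function name `deform_attn_pure_pytorch`, the tensor layouts in the docstring and the use of a plain Python list of `(H, W)` pairs for the spatial shapes are assumptions made for this illustration. Like the call above, it takes no `input_level_start_index` or `im2col_step`.

```python
import torch
import torch.nn.functional as F


def deform_attn_pure_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    """Multi-scale deformable attention without a custom CUDA op (illustrative sketch).

    Assumed layouts:
      value:              (N, sum(H*W), M, D)
      spatial_shapes:     list of (H, W) int pairs, one per feature level
      sampling_locations: (N, Lq, M, L, P, 2), normalized to [0, 1], last dim (x, y)
      attention_weights:  (N, Lq, M, L, P)
    Returns a tensor of shape (N, Lq, M*D).
    """
    N, _, M, D = value.shape
    _, Lq, _, L, P, _ = sampling_locations.shape
    # Split the flattened feature maps back into one chunk per level.
    per_level = value.split([h * w for h, w in spatial_shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1].
    grids = 2 * sampling_locations - 1
    sampled = []
    for lvl, (h, w) in enumerate(spatial_shapes):
        # (N, H*W, M, D) -> (N*M, D, H, W)
        feat = per_level[lvl].flatten(2).transpose(1, 2).reshape(N * M, D, h, w)
        # (N, Lq, M, P, 2) -> (N*M, Lq, P, 2)
        grid = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
        # Bilinear sampling gives (N*M, D, Lq, P).
        sampled.append(F.grid_sample(feat, grid, mode="bilinear",
                                     padding_mode="zeros", align_corners=False))
    # (N, Lq, M, L, P) -> (N*M, 1, Lq, L*P)
    weights = attention_weights.transpose(1, 2).reshape(N * M, 1, Lq, L * P)
    # Weighted sum over all levels and points: (N*M, D, Lq).
    out = (torch.stack(sampled, dim=-2).flatten(-2) * weights).sum(-1)
    return out.view(N, M * D, Lq).transpose(1, 2).contiguous()


# Small sanity check on made-up shapes: constant features, weights summing to 1.
torch.manual_seed(0)
N, M, D, Lq, P = 2, 4, 8, 5, 3
spatial_shapes = [(6, 8), (3, 4)]
L = len(spatial_shapes)
S = sum(h * w for h, w in spatial_shapes)
value = torch.ones(N, S, M, D)
sampling_locations = torch.full((N, Lq, M, L, P, 2), 0.5)
attention_weights = torch.softmax(torch.randn(N, Lq, M, L * P), dim=-1).view(N, Lq, M, L, P)
output = deform_attn_pure_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(output.shape)
```

Because every op here is a standard PyTorch op, the TorchScript compiler never reaches a custom autograd Function on this path. The Python loop over levels and the integer shapes are still worth keeping an eye on when exporting.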

Zalways commented 1 year ago

In ops/modules/ms_deform_attn.py, replace

output = MSDeformAttnFunction.apply(
    value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)

with

output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)

It seems to work for me.

I tried this and the export does work, but the exported model can't be used. Running it fails with this error: (RuntimeError: The size of tensor a (138) must match the size of tensor b (237) at non-singleton dimension 1). I'd appreciate it if you could help me with this problem.