facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

Torchscript scripting C++ model for batch inference #4564

Open H-Y-Zhu opened 2 years ago

H-Y-Zhu commented 2 years ago

Hi,

As described in the detectron2 deployment documentation, the detectron2 TorchScript scripting model supports a dynamic batch size. I am currently trying to modify the official example "[torchscript_mask_rcnn.cpp]" to run batch inference with batch_size > 1, but it does not work.

// create a Tuple[Dict[str, Tensor]] which is the input type of scripted model
c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) {
  const int height = img.rows;
  const int width = img.cols;
  const int channels = 3;

  auto img_tensor =
      torch::from_blob(img.data, {height, width, channels}, torch::kUInt8);
  // HWC to CHW
  img_tensor =
      img_tensor.to(device, torch::kFloat).permute({2, 0, 1}).contiguous();

  std::cout << "img_tensor " << img_tensor.sizes() << std::endl;
  auto img_tensor_l = img_tensor.unsqueeze(0);  // (1, C, H, W); unused below
  std::cout << "img_tensor_l " << img_tensor_l.sizes() << std::endl;

  auto dic = c10::Dict<std::string, torch::Tensor>();
  dic.insert("image", img_tensor);
  return std::make_tuple(dic);
}
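For reference, a minimal Python sketch of the same preprocessing (my own illustration, not detectron2 code; the function name is hypothetical): it wraps one HWC uint8 image into the per-image {"image": CHW float tensor} dict that a single scripted-model input carries.

```python
import numpy as np
import torch

# Hypothetical Python analogue of get_scripting_inputs above:
# wrap one HWC uint8 image into the per-image input dict.
def get_scripting_inputs_py(img: np.ndarray) -> dict:
    t = torch.from_numpy(img)  # (H, W, C), uint8 -- like torch::from_blob
    # HWC -> CHW and cast to float, mirroring .to(device, kFloat).permute(...)
    t = t.to(torch.float32).permute(2, 0, 1).contiguous()
    return {"image": t}

d = get_scripting_inputs_py(np.zeros((4, 5, 3), dtype=np.uint8))
```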

1. The comment in the example says "create a Tuple[Dict[str, Tensor]] which is the input type of scripted model", so I tried to create a

Tuple[Dict[str, Tensor], Dict[str, Tensor]]

by using

return std::make_tuple(dic, dic) 

as the return value of get_scripting_inputs above. However, this reports the error

"Expected a value of type 'Tuple[Dict[str, Tensor]]' for argument 'inputs' but instead found type 'Tuple[Dict[str, Tensor], Dict[str, Tensor]]'."

2. Alternatively, I tried to stack the inputs:

auto inputs = get_scripting_inputs(input_img_resize, device);
std::vector<c10::IValue> inputs_list;

for (int i = 0; i < 5; i++) {
    inputs_list.push_back(inputs);
}
c10::Stack stack{ inputs_list };
auto outputs = model.forward({ stack });

However, it gives the error "Expected at most 2 argument(s) for operator 'forward', but received 6 argument(s). Declaration: forward(torch.ScriptableAdapter self, (Dict(str, Tensor)) inputs) -> (Dict(str, Tensor)[])".
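This second error has the same root cause seen from the other side: forward takes exactly one inputs argument (plus self), and pushing five IValues onto the stack is like calling forward with five positional arguments. A toy Python sketch (my own stand-in module, not detectron2's ScriptableAdapter) reproduces the message:

```python
from typing import Dict
import torch

# Hypothetical stand-in module: forward takes exactly one argument
# besides self, just like the exported adapter's forward.
class OneInput(torch.nn.Module):
    def forward(self, inputs: Dict[str, torch.Tensor]) -> int:
        return len(inputs)

m = torch.jit.script(OneInput())
d = {"image": torch.zeros(1)}
ok = m(d)                 # one argument: fine
try:
    m(d, d)               # like pushing extra IValues onto the stack
    too_many = False
except RuntimeError:      # "Expected at most 2 argument(s) for operator 'forward' ..."
    too_many = True
```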

Can anyone provide some advice on how to achieve batch inference with TorchScript scripting? Many thanks.

github-actions[bot] commented 2 years ago

You've chosen to report an unexpected problem or bug. Unless you already know the root cause of it, please include details about it by filling the issue template. The following information is missing: "Instructions To Reproduce the Issue and Full Logs"; "Your Environment";

ppwwyyxx commented 2 years ago

dic.insert("image", img_tensor);

This img_tensor has shape (N, C, H, W). If N>1, it's batch inference.

H-Y-Zhu commented 2 years ago

Hi ppwwyyxx,

Thanks very much for the reply. However, it seems the shape of 'img_tensor' is (C, H, W). I tried unsqueezing it to (1, C, H, W) and sending that to the network, but it fails with: RuntimeError: AssertionError: ResNet takes an input of shape (N, C, H, W). Got [1, 1, 3, 704, 960] instead!

ppwwyyxx commented 2 years ago

Sorry, my bad.

ppwwyyxx commented 2 years ago
diff --git i/tools/deploy/export_model.py w/tools/deploy/export_model.py
index f09c5c3b..550f35d6 100755
--- i/tools/deploy/export_model.py
+++ w/tools/deploy/export_model.py
@@ -85,7 +85,7 @@ def export_scripting(torch_model):
     if isinstance(torch_model, GeneralizedRCNN):

         class ScriptableAdapter(ScriptableAdapterBase):
-            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
+            def forward(self, inputs: List[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                 instances = self.model.inference(inputs, do_postprocess=False)
                 return [i.get_fields() for i in instances]

diff --git i/tools/deploy/torchscript_mask_rcnn.cpp w/tools/deploy/torchscript_mask_rcnn.cpp
index fd6e1e9f..43044b25 100644
--- i/tools/deploy/torchscript_mask_rcnn.cpp
+++ w/tools/deploy/torchscript_mask_rcnn.cpp
@@ -63,7 +63,7 @@ c10::IValue get_scripting_inputs(cv::Mat& img, c10::Device device) {
       img_tensor.to(device, torch::kFloat).permute({2, 0, 1}).contiguous();
   auto dic = c10::Dict<std::string, torch::Tensor>();
   dic.insert("image", img_tensor);
-  return std::make_tuple(dic);
+  return c10::List<c10::Dict<std::string, torch::Tensor>>({dic, dic});
 }

 c10::IValue

This diff will run inference with a batch size of 2. (You'll need to re-export the model)
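To illustrate why the List annotation enables batching, here is a toy scripted adapter with the same signature shape (my own sketch; ToyAdapter and its "sum" output are stand-ins, not detectron2 names). With the List input type from the diff, the scripted forward accepts a list of per-image dicts of any length:

```python
from typing import Dict, List
import torch

class ToyAdapter(torch.nn.Module):
    # Stand-in for the patched ScriptableAdapter: a List input type
    # accepts any number of per-image dicts, so batch size is dynamic.
    def forward(
        self, inputs: List[Dict[str, torch.Tensor]]
    ) -> List[Dict[str, torch.Tensor]]:
        out: List[Dict[str, torch.Tensor]] = []
        for d in inputs:
            out.append({"sum": d["image"].sum().unsqueeze(0)})
        return out

m = torch.jit.script(ToyAdapter())
batch = [{"image": torch.ones(3, 4, 4)}, {"image": torch.zeros(3, 4, 4)}]
results = m(batch)  # batch size 2, matching the c10::List({dic, dic}) call
```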

H-Y-Zhu commented 2 years ago

Thank you very much, yes, it works now!

Zalways commented 1 year ago

Hi ppwwyyxx! I am trying to use "export_model.py" to export my own model (based on Deformable DETR) to TorchScript. Exporting with the trace method fails at first; I worked around it by adding the line "output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights)" in "ms_deform_attn.py", after which the export succeeds. However, the exported model does not work and fails with: RuntimeError: The size of tensor a (32) must match the size of tensor b (237) at non-singleton dimension 1. So I doubt whether a model based on Deformable DETR can be exported to TorchScript format and work correctly. Looking forward to your reply, and I hope to get your advice! @ppwwyyxx