open-mmlab / mmdetection

OpenMMLab Detection Toolbox and Benchmark
https://mmdetection.readthedocs.io
Apache License 2.0

Error happened when using onnx model converted from fcos model #1655

Closed (zhoulukuan closed this issue 4 years ago)

zhoulukuan commented 4 years ago

I have converted the FCOS model to ONNX successfully. However, when I tried to use the ONNX model for inference, an error happened:

==> Context: Bad node spec: input: "562" output: "563" op_type: "Slice" attribute { name: "axes" ints: 0 type: INTS } attribute { name: "ends" ints: 4 type: INTS } attribute { name: "starts" ints: 2 type: INTS }
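
For context, this kind of error typically surfaces when validating the exported graph; a minimal sketch, assuming the model was saved as fcos.onnx (a hypothetical path):

    import onnx

    # Load the exported graph (path is a placeholder).
    model = onnx.load("fcos.onnx")

    # check_model validates every node against its op schema and raises a
    # ValidationError carrying a "Bad node spec" context like the one above.
    onnx.checker.check_model(model)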

It seems this node has a problem. I want to know which operation in the FPN it corresponds to:

 %560 : Tensor = onnx::Constant[value= 1  1 [ Variable[CPUType]{2} ]](), scope: FCOS/FPN[neck]
  %561 : Tensor = onnx::Cast[to=1](%559), scope: FCOS/FPN[neck]
  %562 : Tensor = onnx::Shape(%544), scope: FCOS/FPN[neck]
  %563 : Tensor = onnx::Slice[axes=[0], ends=[4], starts=[2]](%562), scope: FCOS/FPN[neck]
  %564 : Tensor = onnx::Cast[to=1](%563), scope: FCOS/FPN[neck]
  %565 : Tensor = onnx::Div(%561, %564), scope: FCOS/FPN[neck]
  %566 : Tensor = onnx::Concat[axis=0](%560, %565), scope: FCOS/FPN[neck]
  %567 : Float(1, 256, 50, 50) = onnx::Upsample[mode="nearest"](%544, %566), scope: FCOS/FPN[neck]
  %568 : Float(1, 256, 50, 50) = onnx::Add(%543, %567), scope: FCOS/FPN[neck]
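
For reference, this Shape -> Slice -> Div -> Concat -> Upsample chain is what PyTorch's ONNX exporter typically emits for a size-based F.interpolate, which is how the FPN top-down path upsamples a coarser level onto its lateral feature. A minimal sketch under that assumption (TopDownAdd is a hypothetical stand-in module, not mmdetection's actual FPN class):

    import torch
    import torch.nn.functional as F

    class TopDownAdd(torch.nn.Module):
        # One FPN top-down step: upsample 'top' to the lateral's size, then add.
        def forward(self, lateral, top):
            # Interpolating to an explicit size forces the exporter to compute
            # the scale factors at runtime: Shape(input) is sliced at [2:4] to
            # get (H, W), cast to float, and divided into the target size
            # before being concatenated with [1, 1] and fed to onnx::Upsample.
            return lateral + F.interpolate(top, size=lateral.shape[2:], mode='nearest')

The Slice node named in the error message appears to be that shape[2:4] extraction of the spatial dimensions.
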
huangzicheng commented 4 years ago

ONNX does not support GroupNorm, and FCOS uses GroupNorm.
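
A quick way to confirm GroupNorm is present is to scan the built detector's modules; a minimal sketch, assuming model is an already-built mmdetection detector instance:

    import torch.nn as nn

    # 'model' is assumed to be a built detector (e.g. via mmdet's
    # init_detector). FCOS configs typically set
    # norm_cfg=dict(type='GN', num_groups=32) in the bbox head.
    gn_layers = [name for name, m in model.named_modules()
                 if isinstance(m, nn.GroupNorm)]
    print(gn_layers)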

CoinCheung commented 4 years ago

Hi, would you please tell me how you converted the FCOS model to ONNX?

zhoulukuan commented 4 years ago

@CoinCheung I used torch2trt for the conversion. You can find the project here: https://github.com/NVIDIA-AI-IOT/torch2trt.
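
For reference, the basic torch2trt workflow looks roughly like this; a minimal sketch with a toy module standing in for the detector (note that torch2trt converts directly to a TensorRT engine rather than going through ONNX):

    import torch
    from torch2trt import torch2trt

    # Toy stand-in for the detector; torch2trt needs the model in eval
    # mode on the GPU plus example inputs to trace it.
    model = torch.nn.Conv2d(3, 16, 3, padding=1).eval().cuda()
    x = torch.randn(1, 3, 224, 224).cuda()

    model_trt = torch2trt(model, [x])  # builds a TensorRT engine
    y_trt = model_trt(x)               # runs like a normal module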