open-mmlab / mmdeploy

OpenMMLab Model Deployment Framework
https://mmdeploy.readthedocs.io/en/latest/
Apache License 2.0
2.62k stars 605 forks

[Bug] CoreML: Pre- and post-processing for masks and padding #1477

Open Typiqally opened 1 year ago

Typiqally commented 1 year ago

Describe the bug

I converted Mask R-CNN to Core ML. The conversion completes successfully and the model works as expected. However, the converted model only returns masks of size 28×28 (as specified in the model config), instead of the masks resized to the original image that the default pipeline produces. Currently, I rescale each mask to the width and height of its bounding box using bilinear interpolation. When I visualize the masks, there is a slight offset (see the right and bottom sides).

output

Someone told me this might be related to the padding step in the pre-processing pipeline, which requires the image dimensions to be a multiple of 32. I tried an image size of 800×800 (800 = 32·25), but the issue persists. In fact, the image above is 800×800.
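
To make the arithmetic explicit: a pad-to-divisor step rounds each side up to the next multiple of the divisor, so an 800×800 input would be left unchanged. The sketch below is only an illustration; `padded_size` is a hypothetical helper, not an mmdet or mmdeploy function, and it assumes the padding step rounds up to a multiple of 32.

```python
import math


def padded_size(size, divisor=32):
    """Round `size` up to the next multiple of `divisor` (hypothetical helper)."""
    return math.ceil(size / divisor) * divisor


# 800 is already a multiple of 32 (32 * 25), so no padding would be added.
print(padded_size(800))
# A size such as 801 would be padded up to the next multiple of 32.
print(padded_size(801))
```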

Reproduction

  1. Convert the model:
    python libs/mmdeploy/tools/deploy.py \
    libs/mmdeploy/configs/mmdet/instance-seg/instance-seg_coreml_static-800x800.py \
    checkpoints/mask_rcnn_regnetx-3.2GF_fpn_mstrain_3x_coco.py \
    checkpoints/mask_rcnn_regnetx-3.2GF_fpn_mstrain_3x_coco_20200521_202221-99879813.pth \
    libs/validation.png \
    --work-dir out/mask_rcnn \
    --device cpu
  2. Inference
    import coremltools as ct

    model = ct.models.MLModel('../../tools/deploy/out/mask_rcnn/end2end.mlpackage')

    # `img` is the 800×800 input image, loaded beforehand
    out_dict = model.predict({'img_1': img})

    detections = out_dict['detections'][0]
    labels = out_dict['labels'][0]
    masks = out_dict['masks'][0]

    print(masks)
    print(detections)

  3. Visualize
```python
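import matplotlib.pyplot as plt
import numpy as np
import PIL.Image

# img_np, height, width, label_names, label_colors and alpha_blending
# are defined elsewhere (not shown in this report).


def resize_image(array, dimensions, resampling):
    out = PIL.Image.fromarray(array).resize(dimensions, resampling)
    return np.array(out)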

img_arr = img_np.copy()
mask_overlays = np.zeros((detections.shape[0], height, width))

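# Process detections in order of ascending confidence
for i in detections[:, 4].argsort():
    detection = detections[i]

    # Skip entries containing any zero (presumably empty detection slots)
    if np.any(detection == 0):
        continue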

    # Define values
    x1, y1, x2, y2, confidence = detection
    x = int(x1)
    y = int(y1)
    w = int(x2 - x1)
    h = int(y2 - y1)

    # Fetch meta
    label = label_names[labels[i]]

    print(label, x, y, w, h, confidence)
    color = label_colors[labels[i]]
    mask = masks[i]
    mask = resize_image(mask, (w, h), PIL.Image.Resampling.BILINEAR)

    mask_overlays[i, y:y + h, x:x + w] = mask

for i, mask_overlay in enumerate(mask_overlays):
    for my in range(height):
        for mx in range(width):
            if mask_overlay[my, mx] > 0.4:
                img_arr[my, mx, :] = alpha_blending(img_arr[my, mx], label_colors[labels[i]], mask_overlay[my, mx] * 0.9)

plt.imshow(img_arr)
plt.show()
```

Environment

2022-12-01 13:27:14,584 - mmdeploy - INFO - **********Environmental information**********
2022-12-01 13:27:14,975 - mmdeploy - INFO - sys.platform: darwin
2022-12-01 13:27:14,976 - mmdeploy - INFO - Python: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ]
2022-12-01 13:27:14,976 - mmdeploy - INFO - CUDA available: False
2022-12-01 13:27:14,976 - mmdeploy - INFO - GCC: Apple clang version 14.0.0 (clang-1400.0.29.202)
2022-12-01 13:27:14,976 - mmdeploy - INFO - PyTorch: 1.9.0.post2
2022-12-01 13:27:14,976 - mmdeploy - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 4.2
  - C++ Version: 201402
  - clang 11.1.0
  - OpenMP 201811
  - NNPACK is enabled
  - CPU capability usage: NO AVX
  - Build settings: BLAS_INFO=open, BUILD_TYPE=Release, CXX_COMPILER=/Users/runner/miniforge3/conda-bld/pytorch-recipe_1629200524980/_build_env/bin/arm64-apple-darwin20.0.0-clang++, CXX_FLAGS=-ftree-vectorize -fPIC -fPIE -fstack-protector-strong -O2 -pipe -stdlib=libc++  -std=c++14 -fmessage-length=0 -isystem /Users/runner/miniforge3/conda-bld/pytorch-recipe_1629200524980/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac/include -fdebug-prefix-map=/Users/runner/miniforge3/conda-bld/pytorch-recipe_1629200524980/work=/usr/local/src/conda/pytorch-1.9.0 -fdebug-prefix-map=/Users/runner/miniforge3/conda-bld/pytorch-recipe_1629200524980/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac=/usr/local/src/conda-prefix -Wno-deprecated-declarations -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp=libomp -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -Wno-invalid-partial-specialization -Wno-typedef-redefinition -Wno-unknown-warning-option -Wno-unused-private-field -Wno-inconsistent-missing-override -Wno-aligned-allocation-unavailable -Wno-c++14-extensions -Wno-constexpr-not-const -Wno-missing-braces -Qunused-arguments -fcolor-diagnostics -fno-math-errno -fno-trapping-math -Werror=format 
-Werror=cast-function-type -Wno-unused-private-field -Wno-missing-braces -Wno-c++14-extensions -Wno-constexpr-not-const, LAPACK_INFO=open, TORCH_VERSION=1.9.0, USE_CUDA=OFF, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKLDNN=OFF, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=ON, USE_OPENMP=ON, 
2022-12-01 13:27:14,976 - mmdeploy - INFO - TorchVision: 0.10.0a0
2022-12-01 13:27:14,976 - mmdeploy - INFO - OpenCV: 4.6.0
2022-12-01 13:27:14,976 - mmdeploy - INFO - MMCV: 1.7.0
2022-12-01 13:27:14,976 - mmdeploy - INFO - MMCV Compiler: clang 14.0.0
2022-12-01 13:27:14,976 - mmdeploy - INFO - MMCV CUDA Compiler: not available
2022-12-01 13:27:14,976 - mmdeploy - INFO - MMDeploy: 0.10.0+abc7ec5
2022-12-01 13:27:14,976 - mmdeploy - INFO - 
2022-12-01 13:27:14,976 - mmdeploy - INFO - **********Backend information**********
2022-12-01 13:27:15,250 - mmdeploy - INFO - onnxruntime: 1.13.1 ops_is_avaliable : False
2022-12-01 13:27:15,252 - mmdeploy - INFO - tensorrt: None      ops_is_avaliable : False
2022-12-01 13:27:15,263 - mmdeploy - INFO - ncnn: None  ops_is_avaliable : False
2022-12-01 13:27:15,264 - mmdeploy - INFO - pplnn_is_avaliable: False
2022-12-01 13:27:15,265 - mmdeploy - INFO - openvino_is_avaliable: False
2022-12-01 13:27:15,275 - mmdeploy - INFO - snpe_is_available: False
2022-12-01 13:27:15,276 - mmdeploy - INFO - ascend_is_available: False
2022-12-01 13:27:15,737 - mmdeploy - INFO - coreml_is_available: True
2022-12-01 13:27:15,737 - mmdeploy - INFO - 
2022-12-01 13:27:15,737 - mmdeploy - INFO - **********Codebase information**********
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmdet:      2.25.3
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmseg:      None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmcls:      None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmocr:      None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmedit:     None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmdet3d:    None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmpose:     None
2022-12-01 13:27:15,738 - mmdeploy - INFO - mmrotate:   None
2022-12-01 13:27:15,739 - mmdeploy - INFO - mmaction:   None

Error traceback

No response

Typiqally commented 1 year ago

I'm not sure whether this is the actual solution, since I'm still getting mixed results, but with the following patch applied the offset on the masks seems to disappear:

diff --git a/mmdeploy/backend/coreml/ops.py b/mmdeploy/backend/coreml/ops.py
index 0af1aa42..da54307d 100644
--- a/mmdeploy/backend/coreml/ops.py
+++ b/mmdeploy/backend/coreml/ops.py
@@ -77,7 +77,7 @@ def roi_align(context, node):
         normalized_coordinates=False,
         spatial_scale=extrapolation_value,
         box_coordinate_mode='CORNERS_WIDTH_FIRST',
-        sampling_mode='OFFSET_CORNERS',
+        sampling_mode='DEFAULT',
     )

     # CoreML output format: [N, 1, C, h_out, w_out]

However, the mask now lacks information at the edges:

output

irexyc commented 1 year ago

Hi, could you provide the original image for testing?

Typiqally commented 1 year ago

validation_img_800_3

Typiqally commented 1 year ago

To save some time, here is the comparison of the masks between TorchScript and Core ML converted models, using the same model as before with the visualizer from this repository:

output_pytorch

TorchScript

output_coreml

Core ML

You can see that in the Core ML version, the mask is slightly offset toward the top-left corner compared to the TorchScript version.

Note: it seems like the entire bounding box is offset, which causes the mask to also be offset.

irexyc commented 1 year ago

Currently, there is no roi_align op for Core ML. We use the crop_resize op to extract the RoI features instead, which I don't think is mathematically equivalent to roi_align.
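
For reference, here is a rough 1-D sketch of where a RoIAlign-style op samples the feature map, to show what crop_resize would have to reproduce. It follows the commonly used RoIAlign formulation (each output bin averages a few evenly spaced sample points) and is not taken from the mmdeploy or Core ML source; `roi_align_sample_positions` and its parameters are hypothetical names.

```python
import math


def roi_align_sample_positions(x1, x2, spatial_scale, pooled_size, sampling_ratio=0):
    """Return, per output bin, the 1-D coordinates sampled on the feature map.

    Hypothetical helper: follows the usual RoIAlign formulation without the
    half-pixel shift.
    """
    start = x1 * spatial_scale
    end = x2 * spatial_scale
    # Without the half-pixel shift, the RoI is forced to be at least 1 wide.
    roi_size = max(end - start, 1.0)
    bin_size = roi_size / pooled_size
    # Number of sample points per bin (chosen adaptively when sampling_ratio is 0).
    n_samples = sampling_ratio if sampling_ratio > 0 else math.ceil(roi_size / pooled_size)
    return [
        [start + b * bin_size + (s + 0.5) * bin_size / n_samples for s in range(n_samples)]
        for b in range(pooled_size)
    ]


# A box spanning [0, 8] on the feature map, pooled into 2 bins with 2 samples each.
print(roi_align_sample_positions(0, 8, spatial_scale=1.0, pooled_size=2, sampling_ratio=2))
```

Each bin's output is the average of the bilinearly interpolated values at its sample points, so a resize that takes a single sample per output pixel will generally produce different values.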

The aligned parameter of roi_align is set to true by default, which adds a -0.5 offset to the start of the RoI. You could temporarily change it as shown below. The result will be a little better.

And I will ask my colleagues if there is a better way.

diff --git a/mmdeploy/backend/coreml/ops.py b/mmdeploy/backend/coreml/ops.py
index 0af1aa42..36ad1796 100644
--- a/mmdeploy/backend/coreml/ops.py
+++ b/mmdeploy/backend/coreml/ops.py
@@ -51,14 +51,21 @@ def roi_align(context, node):
         const_box_info = False

     extrapolation_value = context[node.inputs[2]].val
+    aligned = inputs[6].val
     # CoreML index information along with boxes
     if const_box_info:
         boxes = context[node.inputs[1]].val
         # CoreML expects boxes/ROI in
         # [N, 1, 5, 1, 1] format
+        if aligned:
+            boxes[:, 1:] -= 0.5 / extrapolation_value
         boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
     else:
         boxes = inputs[1]
+        if aligned:
+            ind, boxes = mb.split(x=boxes, split_sizes=[1, 4], axis=1)
+            boxes = mb.sub(x=boxes, y=0.5 / extrapolation_value)
+            boxes = mb.concat(values=[ind, boxes], axis=1)
         boxes = mb.reshape(
             x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])
     # Get Height and Width of crop
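
The shift in the patch above can be checked with a little arithmetic. With aligned set to true, roi_align maps a box coordinate x to x * spatial_scale - 0.5 on the feature map; subtracting 0.5 / spatial_scale from the box before scaling gives the same value. The sketch below assumes, as the patch does, that the variable named `extrapolation_value` holds the spatial scale; the function names are hypothetical.

```python
import math


def aligned_roi_coord(x, spatial_scale):
    """Feature-map coordinate used by roi_align when aligned is true."""
    return x * spatial_scale - 0.5


def shifted_box_coord(x, spatial_scale):
    """Box coordinate after the patch's shift, then scaled to the feature map."""
    return (x - 0.5 / spatial_scale) * spatial_scale


# Example with a stride-4 feature level (spatial_scale = 1/4).
scale = 0.25
for x in (0.0, 37.5, 799.0):
    print(x, aligned_roi_coord(x, scale), shifted_box_coord(x, scale))
```
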
image

Typiqally commented 1 year ago

Alright, this seems to work as a temporary fix, thank you very much! I'll keep this issue open just in case you find an improved solution for this. If it is not possible due to Core ML incompatibility, feel free to close it.