zhiqwang / yolort

yolort is a runtime stack for YOLOv5 on specialized accelerators such as TensorRT, LibTorch, ONNX Runtime, TVM, and NCNN.
https://zhiqwang.com/yolort
GNU General Public License v3.0

Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript. #153

Closed by mattpopovich 3 years ago

mattpopovich commented 3 years ago

🐛 Bug

When attempting to load the yolov5-rt-stack model (with NMS post-processing) in C++, the following error appears:

>>> Loading model
>>> Other error: 
Unknown builtin op: torchvision::nms.
Could not find any similar ops to torchvision::nms. This op may not exist or may not be currently supported in TorchScript.
:
  File "/usr/local/lib/python3.8/dist-packages/torchvision-0.10.0a0+300a8a4-py3.8-linux-x86_64.egg/torchvision/ops/boxes.py", line 35
    """
    _assert_has_ops()
    return torch.ops.torchvision.nms(boxes, scores, iou_threshold)
           ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
Serialized   File "code/__torch__/torchvision/ops/boxes.py", line 130
  _39 = __torch__.torchvision.extension._assert_has_ops
  _40 = _39()
  _41 = ops.torchvision.nms(boxes, scores, iou_threshold)
        ~~~~~~~~~~~~~~~~~~~ <--- HERE
  return _41
'nms' is being compiled since it was called from '_batched_nms_vanilla'
  File "/usr/local/lib/python3.8/dist-packages/torchvision-0.10.0a0+300a8a4-py3.8-linux-x86_64.egg/torchvision/ops/boxes.py", line 102
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
        curr_keep_indices = nms(boxes[curr_indices], scores[curr_indices], iou_threshold)
                            ~~~ <--- HERE
        keep_mask[curr_indices[curr_keep_indices]] = True
    keep_indices = torch.where(keep_mask)[0]
Serialized   File "code/__torch__/torchvision/ops/boxes.py", line 96
    _22 = torch.index(boxes, _21)
    _23 = annotate(List[Optional[Tensor]], [curr_indices])
    curr_keep_indices = __torch__.torchvision.ops.boxes.nms(_22, torch.index(scores, _23), iou_threshold, )
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    _24 = annotate(List[Optional[Tensor]], [curr_keep_indices])
    _25 = torch.index(curr_indices, _24)
'_batched_nms_vanilla' is being compiled since it was called from 'batched_nms'
Serialized   File "code/__torch__/torchvision/ops/boxes.py", line 5
    idxs: Tensor,
    iou_threshold: float) -> Tensor:
  _0 = __torch__.torchvision.ops.boxes._batched_nms_vanilla
  ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
  _1 = __torch__.torchvision.ops.boxes._batched_nms_coordinate_trick
  if torch.gt(torch.numel(boxes), 4000):
'batched_nms' is being compiled since it was called from 'PostProcess.forward'
Serialized   File "code/__torch__/yolort/models/box_head.py", line 84
    head_outputs: List[Tensor],
    anchors_tuple: Tuple[Tensor, Tensor, Tensor]) -> List[Dict[str, Tensor]]:
    _11 = __torch__.torchvision.ops.boxes.batched_nms
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    batch_size, _12, _13, _14, K, = torch.size(head_outputs[0])
    all_pred_logits = annotate(List[Tensor], [])

To Reproduce (REQUIRED)

Steps to reproduce the behavior:

  1. I was unable to reproduce this with an MCVE, so it is probably an issue specific to my project. I am documenting it here in the hope of helping someone else.
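For reference, the failing call is an ordinary TorchScript load. A minimal loader along these lines (a sketch; the model filename is hypothetical) is enough to hit the error when libtorchvision.so is not actually linked into the binary:

```cpp
// Minimal TorchScript loader sketch. Assumes libtorch and torchvision's
// C++ library are installed; "yolov5s.torchscript.pt" is a hypothetical path.
#include <torch/script.h>       // One-stop header; for torch::jit::load()
#include <torchvision/vision.h> // For the torchvision::nms op used by the model

#include <iostream>

int main() {
  try {
    torch::jit::script::Module module =
        torch::jit::load("yolov5s.torchscript.pt");
    std::cout << ">>> Model loaded" << std::endl;
  } catch (const c10::Error& e) {
    // If libtorchvision.so was dropped at link time, load() throws
    // "Unknown builtin op: torchvision::nms" here.
    std::cerr << ">>> Other error: " << e.what() << std::endl;
    return 1;
  }
  return 0;
}
```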

Expected behavior

The model loads via torch::jit::load() without error.

Environment

# python3 -m torch.utils.collect_env 
Collecting environment information...
PyTorch version: 1.9.0a0+gitd69c22d
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.21.1
Libc version: glibc-2.31

Python version: 3.8 (64-bit runtime)
Python platform: Linux-5.4.0-80-generic-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 11.2.152
GPU models and configuration: 
GPU 0: GeForce GTX 1080
GPU 1: GeForce GTX 1080
GPU 2: GeForce GTX 1080

Nvidia driver version: 460.91.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.1
[pip3] pytorch-lightning==1.3.8
[pip3] torch==1.9.0a0+gitd69c22d
[pip3] torchmetrics==0.4.1
[pip3] torchvision==0.10.0a0+300a8a4
[conda] Could not collect
mattpopovich commented 3 years ago

The solution is to make sure you are including the following headers:

#include <torch/script.h>       // One-stop header; for torch::jit::load()
#include <torchvision/vision.h> // For torchvision NMS in model

If that doesn't work, check the binary with ldd; in my case it showed that the torchvision library was not actually being linked.
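One way to make sure the linker keeps libtorchvision.so is to link against torchvision's exported CMake target rather than relying on the header include alone. A sketch, assuming torchvision's C++ library was built and installed with CMake support (the project and source names here are hypothetical):

```cmake
# Hypothetical CMakeLists.txt fragment for the executable above.
find_package(Torch REQUIRED)
find_package(TorchVision REQUIRED)  # exports the TorchVision::TorchVision target

add_executable(yolo_infer main.cpp)
target_link_libraries(yolo_infer ${TORCH_LIBRARIES} TorchVision::TorchVision)
set_property(TARGET yolo_infer PROPERTY CXX_STANDARD 14)
```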

Failing executable:

# ldd path/to/failing/executable | grep torch 
        libc10.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so (0x00007fe09b466000)
        libtorch.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch.so (0x00007fe089dd3000)
        libtorch_cpu.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so (0x00007fe07fe70000)
        libtorch_cuda.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so (0x00007fe030163000)
        libc10_cuda.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libc10_cuda.so (0x00007fe01fe6b000)

Working executable (note libtorchvision.so):

# ldd /path/to/working/executable | grep torch
        libc10.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libc10.so (0x00007f293c174000)
        libtorchvision.so => /usr/local/lib/libtorchvision.so (0x00007f293bba4000)
        libtorch.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch.so (0x00007f293bb9f000)
        libtorch_cpu.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cpu.so (0x00007f2931c3c000)
        libtorch_cuda.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libtorch_cuda.so (0x00007f28e1f52000)
        libc10_cuda.so => /usr/local/lib/python3.8/dist-packages/torch/lib/libc10_cuda.so (0x00007f28e16d5000)

In order for libtorchvision.so to be successfully linked, I had to #include <torchvision/models/resnet.h> and then actually use something from it in my code:

auto model = vision::models::ResNet18();

That was the easiest way I found to actually reference a torchvision symbol so the linker would keep the library. After that, ldd on the previously failing executable showed libtorchvision.so, and I was able to successfully load my yolov5-rt-stack model with NMS post-processing in C++! :partying_face:
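An alternative to instantiating a dummy torchvision model is to tell the linker not to drop the "unused" library: with GNU ld, wrapping the torchvision library in --no-as-needed keeps it in the binary, so its static initializers can register torchvision::nms at startup. A sketch with a plain g++ link line and hypothetical paths:

```shell
# Force the linker to keep libtorchvision.so even though no symbol from it
# is referenced directly; loading the library registers torchvision::nms.
g++ main.cpp -o yolo_infer \
    -I/usr/local/include -L/usr/local/lib \
    -Wl,--no-as-needed -ltorchvision -Wl,--as-needed \
    -ltorch -ltorch_cpu -lc10
```

The same effect can be had in CMake via target_link_options, but the key point is that the default --as-needed behavior on many distributions strips libraries whose symbols are never referenced directly.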

zhiqwang commented 3 years ago

Hi @mattpopovich, thanks for the tip here, it was very useful. I ran into similar ldd linking problems when I first compiled torchvision.