facebookresearch / Detic

Code release for "Detecting Twenty-thousand Classes using Image-level Supervision".
Apache License 2.0
1.87k stars 211 forks source link

Export Detic to ONNX or Torchscript #107

Open ghazalmatt3r opened 1 year ago

ghazalmatt3r commented 1 year ago

Hi,

I was wondering whether anyone has been able to export Detic to ONNX or TorchScript, or to serve the model in any other way. There are three other issues on this topic (#61, #63, #68), but none of them reached a conclusive answer.

I'm using detectron2's export_model script as mentioned in previous issues. Here's a snippet of the main part of the code:

import argparse
import os
import torch
import sys
from pathlib import Path
from torch import Tensor, nn

import detectron2.data.transforms as T
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import build_detection_test_loader, detection_utils
from detectron2.export import TracingAdapter, dump_torchscript_IR
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.projects.point_rend import add_pointrend_config
from detectron2.structures import Boxes
from detectron2.utils.env import TORCH_VERSION
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger

# Locate the Detic checkout that sits next to this script and make its
# packages (and the vendored CenterNet2) importable.
FILE = Path(__file__).resolve()
ROOT = FILE.parent / 'Detic'  # Detic root directory
CONFIGS = ROOT / 'configs'

# Appended (not prepended): entries already on sys.path take precedence.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

# Prepended unconditionally so the vendored CenterNet2 wins over any other copy.
sys.path.insert(0, str(ROOT / 'third_party' / 'CenterNet2'))

# NOTE: this changes the process working directory at import time. Every
# relative path used afterwards -- including CLI arguments such as --output,
# --sample-image and --config-file -- is resolved against ROOT, not against the
# directory the script was launched from. Presumably done because Detic configs
# reference data files relative to the repo root -- TODO confirm.
os.chdir(ROOT)

from centernet.config import add_centernet_config
from detic.config import add_detic_config

def setup_cfg(args):
    """Build and freeze the detectron2 config for the model to export.

    Registers the PointRend, CenterNet and Detic config keys, then applies
    ``args.config_file`` followed by the ``KEY VALUE`` overrides collected in
    ``args.opts``.

    Returns:
        The frozen ``CfgNode``.
    """
    cfg = get_cfg()
    # cuda context is initialized before creating dataloader, so we don't fork anymore
    cfg.DATALOADER.NUM_WORKERS = 0
    # The extra keys must be registered before merging a file that uses them.
    for register_keys in (add_pointrend_config, add_centernet_config, add_detic_config):
        register_keys(cfg)
    cfg.merge_from_file(args.config_file)
    print(args.opts)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg

# experimental. API not yet final
def export_tracing(torch_model, inputs, *, export_format=None, output_dir=None):
    """Trace ``torch_model`` on one sample image and save it as TorchScript or ONNX.

    ``GeneralizedRCNN`` instances are traced through
    ``model.inference(..., do_postprocess=False)``; any other model is traced by
    calling it directly. Writes ``model.ts`` (plus the IR dump) or ``model.onnx``
    into the output directory and logs the flattened input/output schemas.

    Args:
        torch_model: the detectron2 model to export.
        inputs: list of detectron2-style input dicts. Only ``inputs[0]["image"]``
            is used as the tracing example; all other keys are dropped.
        export_format: ``"torchscript"`` or ``"onnx"``. Defaults to the script-level
            ``args.format`` (so it only works without arguments under ``__main__``).
        output_dir: directory receiving the exported file. Defaults to the
            script-level ``args.output``.

    Returns:
        None. The exported model is only written to disk.

    Raises:
        ValueError: if ``export_format`` is neither ``"torchscript"`` nor ``"onnx"``.

    NOTE(review): tracing records only the ops run for this example, and
    Python values derived from tensors are frozen as constants (the
    TracerWarnings "will be treated as a constant"). For example, per-image box
    counts become constants. The exported graph may therefore fail on inputs
    that produce a different number of proposals/detections. It may also mix
    CPU shape arithmetic with tensors on the model's device, so loading it
    onto another device can raise device-mismatch errors -- confirm against
    the serving setup.
    """
    assert TORCH_VERSION >= (1, 8)
    # Fall back to the command-line values so existing callers keep working.
    if export_format is None:
        export_format = args.format
    if output_dir is None:
        output_dir = args.output
    # Validate before tracing: tracing a large detector is slow, and an unknown
    # format previously fell through and silently wrote nothing.
    if export_format not in ("torchscript", "onnx"):
        raise ValueError(f"Unsupported export format: {export_format!r}")

    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)
    traceable_model.eval()

    if export_format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(output_dir, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, output_dir)
    else:  # "onnx" (validated above)
        with PathManager.open(os.path.join(output_dir, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=16)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

def get_sample_inputs(args):
    """Return the example batch used to trace the model.

    Without ``--sample-image``, this returns the first batch of the first test
    dataset in ``cfg.DATASETS.TEST``.

    With ``--sample-image``, the image is read and preprocessed the same way as
    DefaultPredictor. It is returned as a one-element list holding a dict with
    ``"image"`` (CHW float32 tensor), ``"height"`` and ``"width"``. The height
    and width are those of the image before resizing.

    Reads the script-level ``cfg`` global.
    """
    if args.sample_image is not None:
        # Do same preprocessing as DefaultPredictor: shortest-edge resize, HWC -> CHW float32.
        raw = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
        resize = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        orig_h, orig_w = raw.shape[:2]
        resized = resize.get_transform(raw).apply_image(raw)
        chw = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
        return [{"image": chw, "height": orig_h, "width": orig_w}]

    # No sample image given: take the first batch from the test dataset.
    loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    return next(iter(loader))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["onnx", "torchscript"],
        help="output format",
        default="onnx",
    )
    parser.add_argument(
        "--export-method",
        choices=["tracing"],
        help="Method to export models",
        default="tracing",
    )
    # NOTE: the module chdir()s into the Detic root at import time, so relative
    # paths given below are resolved against that directory, not the caller's cwd.
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    # Required: without it PathManager.mkdirs(None) fails with an obscure error.
    parser.add_argument(
        "--output", required=True, help="output directory for the converted model"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)

    cfg = setup_cfg(args)
    logger.info("Full config:\n" + str(cfg))

    # create a torch model and put it in inference mode before tracing
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # get sample data
    sample_inputs = get_sample_inputs(args)

    # convert and save model; export_tracing writes the file(s) itself and returns None
    if args.export_method == "tracing":
        export_tracing(torch_model, sample_inputs)

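This is how I'm running the code (`--format` is set to either `torchscript` or `onnx`; `torchscript/onnx` below is shorthand for the two separate runs):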

python export_model.py --config-file $detic_config_dir/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml  --output ./output  --export-method tracing --format torchscript/onnx --sample-image ./00001.jpg MODEL.WEIGHTS $detic_ckpt_dir/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth MODEL.DEVICE cuda

I removed this line from custom_rcnn.py and exported the model to both onnx and torchscript formats:

box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages)

There are a great many tracer warnings (log in the first comment below), but Triton can load and serve both models (Status: READY). The problem is that I still cannot get any inference results from either model. Here are the errors I'm getting:

ONNX:

root@lambda-quad:/workspace# python client.py 
Traceback (most recent call last):
  File "client.py", line 48, in <module>
    results = client.infer(model_name="detic_onnx", inputs=[request], outputs=outputs)
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 1490, in infer
    _raise_if_error(response)
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 65, in _raise_if_error
    raise error
tritonclient.utils.InferenceServerException: onnx runtime error 1: Non-zero status code returned while running Split node. Name:'/roi_heads/Split_3' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={0,22048} NumOutputs=1 Num entries in 'split' (must equal number of outputs) was 1 Sum of sizes in 'split' (must equal size of selected axis) was 256

Torchscript:

root@lambda-quad:/workspace# python client.py 
Traceback (most recent call last):
  File "client.py", line 48, in <module>
    results = client.infer(model_name="detic_torchscript", inputs=[request], outputs=outputs)
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 1490, in infer
    _raise_if_error(response)
  File "/usr/local/lib/python3.8/dist-packages/tritonclient/http/__init__.py", line 65, in _raise_if_error
    raise error
tritonclient.utils.InferenceServerException: PyTorch execute failure: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/detectron2/export/flatten.py", line 26, in forward
    image_size = torch.stack([_1, _2])
    max_size, _3 = torch.max(torch.stack([image_size]), 0)
    _4 = torch.div(torch.add(max_size, CONSTANTS.c0), CONSTANTS.c1, rounding_mode="floor")
         ~~~~~~~~~ <--- HERE
    max_size0 = torch.mul(_4, CONSTANTS.c1)
    _5 = torch.sub(torch.select(max_size0, 0, -1), torch.select(image_size, 0, 1))

Traceback of TorchScript, original code (most recent call last):
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/image_list.py(101): from_tensors
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/meta_arch/rcnn.py(229): preprocess_image
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/meta_arch/custom_rcnn.py(96): inference
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/export_model.py(60): inference
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/export/flatten.py(294): forward
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py(1182): _slow_forward
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/nn/modules/module.py(1194): _call_impl
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/export_model.py(70): export_tracing
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/export_model.py(172): <module>
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Are these errors caused by the tracer warnings? Can anyone help me solve this, please?

@anshudaur Were you able to successfully get inference from the model you exported? If yes, would it be possible to share your code and your model?

Thank you!

ghazalmatt3r commented 1 year ago

Tracer Warning log:

/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/image_list.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert t.shape[:-2] == tensors[0].shape[:-2], t.shape
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:430: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if W % self.patch_size[1] != 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:432: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if H % self.patch_size[0] != 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:369: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  Hp = int(np.ceil(H / self.window_size)) * self.window_size
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:370: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  Wp = int(np.ceil(W / self.window_size)) * self.window_size
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:210: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, "input feature has wrong size"
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:73: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  B = int(windows.shape[0] / (H * W / window_size / window_size))
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:248: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_r > 0 or pad_b > 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:279: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, "input feature has wrong size"
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:284: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pad_input = (H % 2 == 1) or (W % 2 == 1)
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:285: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_input:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:188: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  shapes_per_level = grids[0].new_tensor(
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:739: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  keep = cls_scores >= image_thresh.item()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/fast_rcnn.py:138: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/fast_rcnn.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_bbox_reg_classes == 1:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/mask_head.py:139: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if cls_agnostic_mask:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/image_list.py:85: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert t.shape[:-2] == tensors[0].shape[:-2], t.shape
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:430: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if W % self.patch_size[1] != 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:432: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if H % self.patch_size[0] != 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:369: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  Hp = int(np.ceil(H / self.window_size)) * self.window_size
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:370: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  Wp = int(np.ceil(W / self.window_size)) * self.window_size
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:210: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, "input feature has wrong size"
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:73: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  B = int(windows.shape[0] / (H * W / window_size / window_size))
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:248: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_r > 0 or pad_b > 0:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:279: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert L == H * W, "input feature has wrong size"
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:284: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  pad_input = (H % 2 == 1) or (W % 2 == 1)
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/detic/modeling/backbone/swintransformer.py:285: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if pad_input:
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:188: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  shapes_per_level = grids[0].new_tensor(
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:692: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if per_candidate_inds.sum().item() > per_pre_nms_top_n.item():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/projects/Real2Sim/object_tracking/matt3r_tracking/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py:739: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  keep = cls_scores >= image_thresh.item()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/instances.py:147: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
  return v.__len__()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/fast_rcnn.py:138: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if not valid_mask.all():
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if tensor.numel() == 0:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:191: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!"
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/structures/boxes.py:192: TracerWarning: Iterating over a tensor might cause the trace to be incorrect. Passing a tensor of different shape won't change the number of iterations executed (and might lead to errors or silently give incorrect results).
  h, w = box_size
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/fast_rcnn.py:155: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if num_bbox_reg_classes == 1:
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/nms.py:15: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert boxes.shape[-1] == 4
/data2/ghazal/miniconda3/lib/python3.9/site-packages/torch/__init__.py:853: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert condition, message
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/layers/roi_align.py:55: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert rois.dim() == 2 and rois.size(1) == 5
/data2/ghazal/miniconda3/lib/python3.9/site-packages/detectron2/modeling/roi_heads/mask_head.py:139: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if cls_agnostic_mask:
HtwoOtwo commented 12 months ago

I am facing the same problem as you. Have you solved it?

anshudaur commented 12 months ago

@ghazalmatt3r, @HtwoOtwo you also have to comment out the `nms_and_topK` line in CenterNet while exporting the model:

boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms)

Also, both the ONNX and TorchScript models work fine in Triton Inference Server.

Hope it helps. ✌️

HtwoOtwo commented 12 months ago

> @ghazalmatt3r , @HtwoOtwo you also have to comment the nms_and_topk line in centernet, while exporting the model
>
> boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms)
>
> Also, both models onnx and torchscript works fine in triton inference server.
>
> Hope it helps. ✌️

Wow, it works, thanks a lot!

gigasurgeon commented 11 months ago

@HtwoOtwo do you have an inference script for running the exported ONNX model? I have successfully exported to ONNX but don't know how to run inference with it.

HtwoOtwo commented 11 months ago

> @HtwoOtwo do you have an inference script to infer from ONNX? I have successfully exported to ONNX but don't know how to infer from it.

Hi, you can try the code below:


import argparse
import cv2
import numpy as np
import onnxruntime as ort

class Detic():
    """Minimal ONNX Runtime wrapper around an exported Detic model.

    Loads the ``.onnx`` file, resizes and colour-converts an input image, runs
    the session, post-processes the raw outputs (rescaling boxes to the
    original image, dropping degenerate boxes and low-score detections), and
    can draw the resulting boxes onto an image.
    """

    def __init__(self, modelpath, detection_width=800, confThreshold=0.8,
                 input_size=(1067, 800),
                 class_names_path='models/class_names.txt'):
        """Create an inference session for the exported model.

        Args:
            modelpath: path to the ``.onnx`` file.
            detection_width: edge length used to compute the resize target
                when ``input_size`` is ``None``.
            confThreshold: detections scoring at or below this value are
                dropped by ``post_processing``.
            input_size: fixed ``(width, height)`` passed to ``cv2.resize``.
                The default ``(1067, 800)`` is the size this script always
                used. NOTE(review): a model exported by tracing may only accept
                the input shape it was traced with; confirm before changing.
                Pass ``None`` to derive the size from ``detection_width``.
            class_names_path: text file with one class name per line. If it
                does not exist, numeric class ids are used as labels.
        """
        # Try CUDA first and fall back to CPU so the script also runs on
        # machines without a GPU build of onnxruntime.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        self.session = ort.InferenceSession(modelpath, providers=providers)
        model_inputs = self.session.get_inputs()
        self.input_name = model_inputs[0].name
        self.max_size = detection_width
        self.confThreshold = confThreshold
        self.input_size = input_size

        try:
            with open(class_names_path) as f:
                self.class_names = [line.strip() for line in f]
        except FileNotFoundError:
            self.class_names = []

        # One random colour per known class; fall back to a small palette
        # (indexed modulo its length) when no class names are available.
        palette_size = len(self.class_names) if self.class_names else 4
        self.assigned_colors = np.random.randint(
            0, high=256, size=(palette_size, 3)).tolist()

    def _target_size(self, im_h, im_w):
        """Return the ``(width, height)`` an ``im_h x im_w`` image is resized to."""
        if self.input_size is not None:
            return tuple(self.input_size)
        # Scale the shorter edge to max_size, then shrink again if the longer
        # edge would exceed max_size.
        if im_h < im_w:
            scale = self.max_size / im_h
            oh, ow = self.max_size, scale * im_w
        else:
            scale = self.max_size / im_w
            oh, ow = scale * im_h, self.max_size
        max_hw = max(oh, ow)
        if max_hw > self.max_size:
            scale = self.max_size / max_hw
            oh *= scale
            ow *= scale
        return int(ow + 0.5), int(oh + 0.5)

    def preprocess(self, srcimg):
        """Convert a BGR ``(H, W, 3)`` image to RGB and resize it for the model."""
        im_h, im_w, _ = srcimg.shape
        # NOTE(review): detectron2 models are often configured for BGR input
        # (cfg.INPUT.FORMAT); confirm this RGB conversion matches the export.
        dstimg = cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)
        return cv2.resize(dstimg, self._target_size(im_h, im_w))

    def post_processing(self, pred_boxes, scores, pred_classes, pred_masks, im_hw, pred_hw):
        """Map raw model outputs back to the original image and filter them.

        Args:
            pred_boxes: ``(N, 4)`` array of ``x0, y0, x1, y1`` boxes in
                model-input coordinates. It is scaled and clipped in place.
            scores, pred_classes, pred_masks: arrays with one row per
                detection, filtered with the same mask as the boxes.
            im_hw: ``(height, width)`` of the original image.
            pred_hw: ``(height, width)`` of the image fed to the model.

        Returns:
            A dict with keys ``pred_boxes``, ``scores``, ``pred_classes`` and
            ``pred_masks``. It holds only detections whose clipped box has
            positive width and height and whose score exceeds
            ``self.confThreshold``. Masks are returned as produced by the
            model; they are not pasted into image resolution.
        """
        scale_x = im_hw[1] / pred_hw[1]
        scale_y = im_hw[0] / pred_hw[0]
        pred_boxes[:, 0::2] *= scale_x
        pred_boxes[:, 1::2] *= scale_y
        pred_boxes[:, [0, 2]] = np.clip(pred_boxes[:, [0, 2]], 0, im_hw[1])
        pred_boxes[:, [1, 3]] = np.clip(pred_boxes[:, [1, 3]], 0, im_hw[0])

        widths = pred_boxes[:, 2] - pred_boxes[:, 0]
        heights = pred_boxes[:, 3] - pred_boxes[:, 1]
        # Drop boxes that collapsed to zero area after clipping, and apply
        # the configured confidence threshold (previously hard-coded to 0.5,
        # which made the confThreshold option a no-op).
        keep = (widths > 0) & (heights > 0) & (scores > self.confThreshold)

        return {
            'pred_boxes': pred_boxes[keep],
            'scores': scores[keep],
            'pred_classes': pred_classes[keep],
            'pred_masks': pred_masks[keep],
        }

    def draw_predictions(self, img, predictions):
        """Draw boxes and ``"<label> <score>%"`` captions onto ``img`` in place.

        ``predictions`` is the dict returned by ``post_processing``. Returns
        ``img`` for convenience.
        """
        height, width = img.shape[:2]
        default_font_size = int(max(np.sqrt(height * width) // 90, 10))
        thickness = default_font_size // 4
        boxes = predictions["pred_boxes"].astype(np.int64)
        scores = predictions["scores"]
        classes_id = predictions["pred_classes"].tolist()
        print('detect', len(boxes), 'instances')
        for (x0, y0, x1, y1), score, class_id in zip(boxes.tolist(), scores, classes_id):
            # The palette may be smaller than the vocabulary, so wrap around
            # instead of raising IndexError.
            color = self.assigned_colors[class_id % len(self.assigned_colors)]
            cv2.rectangle(img, (x0, y0), (x1, y1), color=color, thickness=thickness)
            if 0 <= class_id < len(self.class_names):
                label = self.class_names[class_id]
            else:
                label = class_id
            text = "{} {:.0f}%".format(label, round(score, 2) * 100)
            cv2.putText(img, text, (x0, y0 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color,
                        thickness=1, lineType=cv2.LINE_AA)
        return img

    def detect(self, srcimg):
        """Run preprocessing, inference and post-processing on a BGR image."""
        im_h, im_w = srcimg.shape[:2]
        dstimg = self.preprocess(srcimg)
        pred_hw = dstimg.shape[:2]
        # The model is fed a single CHW float32 image without a batch axis.
        input_image = dstimg.transpose(2, 0, 1).astype(np.float32)
        # NOTE(review): the output order depends on how the model was
        # exported; confirm it is (boxes, classes, masks, scores, <unused>).
        pred_boxes, pred_classes, pred_masks, scores, _ = self.session.run(
            None, {self.input_name: input_image})
        return self.post_processing(pred_boxes, scores, pred_classes, pred_masks,
                                    (im_h, im_w), pred_hw)

if __name__ == '__main__':
    # Command-line entry point: detect objects in one image and save an
    # annotated copy of it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, required=True, help="image path")
    parser.add_argument("--confThreshold", default=0.5, type=float, help='class confidence')
    parser.add_argument("--modelpath", type=str, default='onnx_models/model.onnx', help="onnxmodel path")
    parser.add_argument("--outpath", type=str, default='result.jpg',
                        help="where to write the annotated image")
    args = parser.parse_args()

    srcimg = cv2.imread(args.imgpath)
    if srcimg is None:
        # cv2.imread signals failure by returning None rather than raising.
        parser.error(f"could not read image: {args.imgpath}")

    mynet = Detic(args.modelpath, confThreshold=args.confThreshold)
    preds = mynet.detect(srcimg)
    result = mynet.draw_predictions(srcimg, preds)
    cv2.imwrite(args.outpath, result)
gigasurgeon commented 11 months ago

@HtwoOtwo Thank you so much for the inference script. How do we provide a custom vocabulary to the model? I have a list of custom classes that I would like to use.

Jialeen commented 8 months ago

Can the ONNX model run with C++? I get an error when running it with C++: [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Split node. Name:'/roi_heads/Split_3' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={240,7} NumOutputs=1 Num entries in 'split' (must equal number of outputs) was 1 Sum of sizes in 'split' (must equal size of selected axis) was 256 terminate called after throwing an instance of 'Ort::Exception' what(): Non-zero status code returned while running Split node. Name:'/roi_heads/Split_3' Status Message: Cannot split using values in 'split' attribute. Axis=0 Input shape={240,7} NumOutputs=1 Num entries in 'split' (must equal number of outputs) was 1 Sum of sizes in 'split' (must equal size of selected axis) was 256