AILab-CVC / YOLO-World

[CVPR 2024] Real-Time Open-Vocabulary Object Detection
https://www.yoloworld.cc
GNU General Public License v3.0

VRAM requirements during inference? #90

Status: Open. destroy314 opened this issue 9 months ago

destroy314 commented 9 months ago

I tried to run the image_demo.py demo with the YOLO-World-Seg-L model on a GPU with 8 GB of VRAM, but got a CUDA out-of-memory error. What are the VRAM requirements for these models, and is there a way to reduce them?
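One way to answer the VRAM question empirically is to measure the peak memory PyTorch allocates during a single inference call. The sketch below is an illustration, not part of image_demo.py: `measure_peak_vram` is a hypothetical helper, and the commented usage assumes `runner` and `data_batch` as they appear in the traceback further down (`runner.model.test_step(data_batch)`). It counts only tensors that go through PyTorch's caching allocator, so the CUDA context and other processes are not included.

```python
from typing import Any, Callable, Optional, Tuple

import torch


def measure_peak_vram(fn: Callable[[], Any]) -> Tuple[Any, Optional[int]]:
    """Run fn() without autograd and return (result, peak bytes allocated by PyTorch).

    The peak includes tensors already on the GPU (e.g. model weights).
    Returns None for the peak when CUDA is unavailable, so it also runs on CPU.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()  # start a fresh peak counter
    with torch.no_grad():  # inference only: no autograd buffers are kept
        result = fn()
    if not use_cuda:
        return result, None
    torch.cuda.synchronize()  # wait for all kernels before reading the counter
    return result, torch.cuda.max_memory_allocated()


# Example (hypothetical usage, assuming `runner` and `data_batch` from image_demo.py):
# _, peak = measure_peak_vram(lambda: runner.model.test_step(data_batch))
# if peak is not None:
#     print(f"peak allocated by PyTorch: {peak / 2**30:.2f} GiB")
```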

wondervictor commented 9 months ago

Hi @destroy314, that's a bit puzzling; I'm looking into it. In the meantime, could you try mixed-precision inference?

destroy314 commented 9 months ago

Sure, but passing the --amp option resulted in the following error:

Traceback (most recent call last):
  File "image_demo.py", line 158, in <module>
    inference_detector(runner,
  File "image_demo.py", line 79, in inference_detector
    output = runner.model.test_step(data_batch)[0]
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 145, in test_step
    return self._run_forward(data, mode='predict')  # type: ignore
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 361, in _run_forward
    results = self(**data, mode=mode)
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmdet/models/detectors/base.py", line 94, in forward
    return self.predict(inputs, data_samples)
  File "/home/yangzhuo/YOLO-World/yolo_world/models/detectors/yolo_world.py", line 45, in predict
    results_list = self.bbox_head.predict(img_feats,
  File "/home/yangzhuo/YOLO-World/yolo_world/models/dense_heads/yolo_world_seg_head.py", line 326, in predict
    predictions = self.predict_by_feat(*outs,
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmyolo/models/dense_heads/yolov5_ins_head.py", line 631, in predict_by_feat
    results = self._bbox_post_process(
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmdet/models/dense_heads/base_dense_head.py", line 485, in _bbox_post_process
    det_bboxes, keep_idxs = batched_nms(bboxes, results.scores,
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmcv/ops/nms.py", line 303, in batched_nms
    dets, keep = nms_op(boxes_for_nms, scores, **nms_cfg_)
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmengine/utils/misc.py", line 395, in new_func
    output = old_func(*args, **kwargs)
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmcv/ops/nms.py", line 127, in nms
    inds = NMSop.apply(boxes, scores, iou_threshold, offset, score_threshold,
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/yangzhuo/mambaforge/envs/yolo_world/lib/python3.8/site-packages/mmcv/ops/nms.py", line 27, in forward
    inds = ext_module.nms(
RuntimeError: expected scalar type Float but found Half

My command is:

$ python image_demo.py configs/segmentation/yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis.py yolo_world_seg_l_dual_vlpan_2e-4_80e_8gpus_seghead_finetune_lvis-5a642d30.pth ../2107540987.jpg 'pineapple, grape, pear, carrot, orange, mouse, Rubiks cube, apple, mushroom, toilet paper, mineral water, handle, tape, rag, toothpaste' --topk 100 --threshold 0.005 --output-dir demo_outputs --amp

My pip list output is as follows:

```
Package                       Version    Editable project location
----------------------------- ---------- -------------------------
actionlib                     1.14.0
addict                        2.4.0
aliyun-python-sdk-core        2.14.0
aliyun-python-sdk-kms         2.16.2
angles                        1.9.13
bondpy                        1.8.6
camera-calibration            1.17.0
camera-calibration-parsers    1.12.0
catkin                        0.8.10
certifi                       2024.2.2
cffi                          1.16.0
charset-normalizer            3.3.2
click                         8.1.7
cmake                         3.28.3
colorama                      0.4.6
contourpy                     1.1.1
controller-manager            0.20.0
controller-manager-msgs       0.20.0
crcmod                        1.7
cryptography                  42.0.5
cv-bridge                     1.16.2
cycler                        0.12.1
defusedxml                    0.7.1
diagnostic-analysis           1.11.0
diagnostic-common-diagnostics 1.11.0
diagnostic-updater            1.11.0
dynamic-reconfigure           1.7.3
filelock                      3.13.1
fonttools                     4.49.0
fsspec                        2024.2.0
gazebo_plugins                2.9.2
gazebo_ros                    2.9.2
gencpp                        0.7.0
geneus                        3.0.0
genlisp                       0.4.18
genmsg                        0.6.0
gennodejs                     2.0.2
genpy                         0.6.15
huggingface-hub               0.21.3
idna                          3.6
image-geometry                1.16.2
importlib-metadata            7.0.1
importlib_resources           6.1.2
interactive-markers           1.12.0
Jinja2                        3.1.3
jmespath                      0.10.0
joint-state-publisher         1.15.1
joint-state-publisher-gui     1.15.1
kiwisolver                    1.4.5
laser_geometry                1.6.7
lit                           17.0.6
Markdown                      3.5.2
markdown-it-py                3.0.0
MarkupSafe                    2.1.5
matplotlib                    3.7.5
mdurl                         0.1.2
message-filters               1.16.0
mmcv                          2.1.0
mmcv-lite                     2.0.1
mmdet                         3.3.0
mmengine                      0.10.3
mmyolo                        0.6.0
model-index                   0.1.11
moveit-commander              1.1.13
moveit-core                   1.1.13
moveit-ros-planning-interface 1.1.13
moveit-ros-visualization      1.1.13
mpmath                        1.3.0
networkx                      3.1
numpy                         1.24.4
nvidia-cublas-cu11            11.10.3.66
nvidia-cublas-cu12            12.1.3.1
nvidia-cuda-cupti-cu11        11.7.101
nvidia-cuda-cupti-cu12        12.1.105
nvidia-cuda-nvrtc-cu11        11.7.99
nvidia-cuda-nvrtc-cu12        12.1.105
nvidia-cuda-runtime-cu11      11.7.99
nvidia-cuda-runtime-cu12      12.1.105
nvidia-cudnn-cu11             8.5.0.96
nvidia-cudnn-cu12             8.9.2.26
nvidia-cufft-cu11             10.9.0.58
nvidia-cufft-cu12             11.0.2.54
nvidia-curand-cu11            10.2.10.91
nvidia-curand-cu12            10.3.2.106
nvidia-cusolver-cu11          11.4.0.1
nvidia-cusolver-cu12          11.4.5.107
nvidia-cusparse-cu11          11.7.4.91
nvidia-cusparse-cu12          12.1.0.106
nvidia-nccl-cu11              2.14.3
nvidia-nccl-cu12              2.18.1
nvidia-nvjitlink-cu12         12.3.101
nvidia-nvtx-cu11              11.7.91
nvidia-nvtx-cu12              12.1.105
opencv-python                 4.9.0.80
opencv-python-headless        4.9.0.80
opendatalab                   0.0.10
openmim                       0.3.9
openxlab                      0.0.34
ordered-set                   4.1.0
oss2                          2.17.0
packaging                     23.2
pandas                        2.0.3
pillow                        10.2.0
pip                           24.0
platformdirs                  4.2.0
prettytable                   3.10.0
pycocotools                   2.0.7
pycparser                     2.21
pycryptodome                  3.20.0
Pygments                      2.17.2
pyparsing                     3.1.1
python-dateutil               2.8.2
python-qt-binding             0.4.4
pytz                          2023.4
PyYAML                        6.0.1
qt-dotgraph                   0.4.2
qt-gui                        0.4.2
qt-gui-cpp                    0.4.2
qt-gui-py-common              0.4.2
regex                         2023.12.25
requests                      2.28.2
resource_retriever            1.12.7
rich                          13.4.2
ros_numpy                     0.0.5
rosbag                        1.16.0
rosboost-cfg                  1.15.8
rosclean                      1.15.8
roscreate                     1.15.8
rosgraph                      1.16.0
roslaunch                     1.16.0
roslib                        1.15.8
roslint                       0.12.0
roslz4                        1.16.0
rosmake                       1.15.8
rosmaster                     1.16.0
rosmsg                        1.16.0
rosnode                       1.16.0
rosparam                      1.16.0
rospy                         1.16.0
rosservice                    1.16.0
rostest                       1.16.0
rostopic                      1.16.0
rosunit                       1.15.8
roswtf                        1.16.0
rqt_action                    0.4.9
rqt_bag                       0.5.1
rqt_bag_plugins               0.5.1
rqt-console                   0.4.12
rqt_dep                       0.4.12
rqt_graph                     0.4.14
rqt_gui                       0.5.3
rqt_gui_py                    0.5.3
rqt-image-view                0.4.17
rqt_launch                    0.4.9
rqt-logger-level              0.4.12
rqt-moveit                    0.5.11
rqt_msg                       0.4.10
rqt_nav_view                  0.5.7
rqt_plot                      0.4.13
rqt_pose_view                 0.5.11
rqt_publisher                 0.4.10
rqt_py_common                 0.5.3
rqt_py_console                0.4.10
rqt-reconfigure               0.5.5
rqt-robot-dashboard           0.5.8
rqt-robot-monitor             0.5.15
rqt_robot_steering            0.5.12
rqt-runtime-monitor           0.5.10
rqt-rviz                      0.7.0
rqt_service_caller            0.4.10
rqt_shell                     0.4.11
rqt_srv                       0.4.9
rqt-tf-tree                   0.6.4
rqt_top                       0.4.10
rqt_topic                     0.4.13
rqt_web                       0.4.10
rviz                          1.14.20
safetensors                   0.4.2
scipy                         1.10.0
sensor-msgs                   1.13.1
setuptools                    60.2.0
shapely                       2.0.3
six                           1.16.0
smach                         2.5.2
smach-ros                     2.5.2
smach-viewer                  4.1.0
smclib                        1.8.6
srdfdom                       0.6.4
supervision                   0.18.0
sympy                         1.12
tabulate                      0.9.0
termcolor                     2.4.0
terminaltables                3.1.10
tf                            1.13.2
tf-conversions                1.13.2
tf2-geometry-msgs             0.7.7
tf2-kdl                       0.7.7
tf2-py                        0.7.7
tf2-ros                       0.7.7
tf2-sensor-msgs               0.7.7
tokenizers                    0.15.2
tomli                         2.0.1
topic-tools                   1.16.0
torch                         2.0.1
torchvision                   0.15.2
tqdm                          4.65.2
transformers                  4.38.1
triton                        2.0.0
typing_extensions             4.10.0
tzdata                        2024.1
urdfdom-py                    0.4.6
urllib3                       1.26.18
wcwidth                       0.2.13
wheel                         0.42.0
xacro                         1.14.16
yapf                          0.40.2
yolo_world                    0.1.0      /home/yangzhuo/YOLO-World
zipp                          3.17.0
```
wondervictor commented 9 months ago

Hi @destroy314, it seems that the NMS op does not support AMP (FP16 inputs). You can cast the output tensors of head_module to float32.
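A minimal sketch of that suggestion, assuming the head's sub-module is reachable as `runner.model.bbox_head.head_module` (`bbox_head` appears in the traceback above; the `head_module` path is an assumption based on this comment). It registers a PyTorch forward hook that returns a float32 copy of whatever the module outputs, so the post-processing that ends in mmcv's `nms` receives FP32 tensors even when the forward pass runs under autocast. `cast_to_float32` and `install_fp32_output_hook` are hypothetical helper names; later ops that autocast handles (e.g. matrix multiplies) may still produce FP16, so this only covers the tensors the head itself emits.

```python
from typing import Any

import torch
from torch import nn


def cast_to_float32(obj: Any) -> Any:
    """Recursively cast floating-point tensors in nested lists/tuples/dicts to float32."""
    if torch.is_tensor(obj):
        # Leave integer/bool tensors (e.g. indices) untouched.
        return obj.float() if obj.is_floating_point() else obj
    if isinstance(obj, list):
        return [cast_to_float32(x) for x in obj]
    if isinstance(obj, tuple):
        return tuple(cast_to_float32(x) for x in obj)
    if isinstance(obj, dict):
        return {k: cast_to_float32(v) for k, v in obj.items()}
    return obj


def install_fp32_output_hook(module: nn.Module):
    """Make `module` return float32 outputs; returns a handle so the hook can be removed."""
    def hook(_module, _inputs, output):
        # A forward hook that returns a value replaces the module's output.
        return cast_to_float32(output)
    return module.register_forward_hook(hook)


# Hypothetical usage (attribute path assumed):
# handle = install_fp32_output_hook(runner.model.bbox_head.head_module)
# ... run image_demo.py-style inference with --amp ...
# handle.remove()
```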

destroy314 commented 9 months ago

I'm still getting CUDA out of memory, although this time the requested allocation is much smaller than last time 😂

Tried to allocate 5.27 GiB (GPU 0; 7.76 GiB total capacity; 1.12 GiB already allocated; 5.15 GiB free; 1.23 GiB reserved in total by PyTorch)
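Reading the numbers in that message: the failed request (5.27 GiB) exceeds the free memory (5.15 GiB) by only about 0.12 GiB, and roughly 7.76 - 5.15 - 1.23 = 1.38 GiB of the card is in use outside PyTorch's caching allocator (for example the CUDA context or other processes on the same GPU). The sketch below does that arithmetic from the message text; `summarize_oom` is a hypothetical helper, and its regular expression is written against this particular message format, so other PyTorch versions may word the message differently.

```python
import re
from typing import Dict

# Matches the GiB figures in a PyTorch CUDA OOM message of the form shown above.
_OOM_PATTERN = re.compile(
    r"Tried to allocate (?P<requested>[\d.]+) GiB .*?"
    r"(?P<total>[\d.]+) GiB total capacity; "
    r"(?P<allocated>[\d.]+) GiB already allocated; "
    r"(?P<free>[\d.]+) GiB free; "
    r"(?P<reserved>[\d.]+) GiB reserved"
)


def summarize_oom(message: str) -> Dict[str, float]:
    """Extract the GiB figures from an OOM message and derive two summary values."""
    match = _OOM_PATTERN.search(message)
    if match is None:
        raise ValueError("message does not match the expected OOM format")
    v = {k: float(x) for k, x in match.groupdict().items()}
    # How much more free memory the failed request needed.
    v["shortfall"] = round(v["requested"] - v["free"], 2)
    # Device memory in use that PyTorch's caching allocator does not account for.
    v["outside_pytorch"] = round(v["total"] - v["free"] - v["reserved"], 2)
    return v
```

A shortfall this small suggests that freeing memory held by other processes on the card might let this particular allocation succeed, although further allocations may follow it.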