hustvl / SparseInst

[CVPR 2022] SparseInst: Sparse Instance Activation for Real-Time Instance Segmentation
MIT License
590 stars, 71 forks

OOM when running test_net.py and demo.py #71

Open siddagra opened 2 years ago

siddagra commented 2 years ago

How much GPU memory does this use during inference? It is trying to allocate 12 GB, and neither test_net.py nor demo.py has an option for changing the batch size. I changed the batch size in the config, but the result is the same.

I ran this on an RTX 3080:

python3 test_net.py --config-file configs/sparse_inst_r50_giam.yaml --num-gpus 1 MODEL.WEIGHTS output/sparse_inst_r50vd_dcn_giam_aug/model_0006499.pth INPUT.MIN_SIZE_TEST 512

Error:

[08/06 12:41:22 d2.data.dataset_mapper]: [DatasetMapper] Augmentations used in inference: [ResizeShortestEdge(short_edge_length=(512, 512), max_size=853, sample_style='choice')]
[08/06 12:41:22 d2.data.common]: Serializing 30 elements to byte tensors and concatenating them all ...
[08/06 12:41:22 d2.data.common]: Serialized dataset takes 2.45 MiB
/home/siddharth/.local/lib/python3.8/site-packages/detectron2/structures/image_list.py:88: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  max_size = (max_size + (stride - 1)) // stride * stride
/home/siddharth/.local/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at  ../aten/src/ATen/native/TensorShape.cpp:2157.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "test_net.py", line 196, in <module>
    test_sparseinst_speed(cfg, fp16=args.fp16)
  File "test_net.py", line 157, in test_sparseinst_speed
    output = model(images, resized_size, ori_size)
  File "/home/siddharth/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "test_net.py", line 82, in forward
    result = self.inference_single(
  File "test_net.py", line 116, in inference_single
    pred_masks = F.interpolate(pred_masks, size=ori_shape, mode='bilinear',
  File "/home/siddharth/.local/lib/python3.8/site-packages/torch/nn/functional.py", line 3731, in interpolate
    return torch._C._nn.upsample_bilinear2d(input, output_size, align_corners, scale_factors)
RuntimeError: CUDA out of memory. Tried to allocate 11.85 GiB (GPU 0; 9.78 GiB total capacity; 431.67 MiB already allocated; 6.52 GiB free; 462.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
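Judging from the traceback, the allocation fails in the F.interpolate call inside inference_single, which upsamples all predicted masks to ori_shape, the original image size. The size of that output tensor therefore follows the original resolution, not INPUT.MIN_SIZE_TEST. A rough estimate, assuming float32 masks for a single image, is num_masks × H × W × 4 bytes. The helper names below and the example values (100 masks, a 4000 × 8000 image) are hypothetical, not taken from the issue.

```python
def upsampled_mask_bytes(num_masks: int, height: int, width: int,
                         bytes_per_element: int = 4) -> int:
    """Bytes needed for a (1, num_masks, height, width) mask tensor.

    bytes_per_element=4 assumes float32 masks; use 2 to model float16.
    """
    return num_masks * height * width * bytes_per_element


def to_gib(num_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes), the unit PyTorch uses in OOM messages."""
    return num_bytes / 2**30


if __name__ == "__main__":
    # Hypothetical example: 100 masks upsampled to a 4000 x 8000 original image.
    needed = upsampled_mask_bytes(100, 4000, 8000)
    print(f"{to_gib(needed):.2f} GiB")  # prints 11.92 GiB
```

Because the estimate scales with the original H × W, shrinking the test-time input size alone would not be expected to reduce this particular allocation.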
siddagra commented 2 years ago

I think the input images are very large, so the upsampling step crashes. Fixing this is going to be a pain, since I would also need to resize the COCO annotations. Is there a way to evaluate on resized images (along with resized annotations)?
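If the images are downscaled offline, the COCO JSON has to be scaled by the same factor. Below is a minimal sketch, assuming one uniform scale factor per image and polygon (not RLE) segmentations; scale_coco_annotation and scale_coco_image are hypothetical helper names, not part of SparseInst or detectron2. The image files themselves still have to be resized with the same factor, which this sketch does not do.

```python
from typing import Any, Dict


def scale_coco_annotation(ann: Dict[str, Any], scale: float) -> Dict[str, Any]:
    """Return a copy of one COCO annotation with coordinates multiplied by `scale`.

    Assumes polygon segmentations (a list of flat [x1, y1, x2, y2, ...] lists);
    RLE-encoded masks would have to be decoded and resized instead.
    """
    scaled = dict(ann)
    x, y, w, h = ann["bbox"]  # COCO bbox format is [x, y, width, height]
    scaled["bbox"] = [x * scale, y * scale, w * scale, h * scale]
    if isinstance(ann.get("segmentation"), list):
        scaled["segmentation"] = [
            [coord * scale for coord in polygon] for polygon in ann["segmentation"]
        ]
    elif "segmentation" in ann:
        raise ValueError("RLE segmentation is not handled by this sketch")
    # Area scales with the square of the linear factor.
    if "area" in ann:
        scaled["area"] = ann["area"] * scale * scale
    return scaled


def scale_coco_image(img: Dict[str, Any], scale: float) -> Dict[str, Any]:
    """Return a copy of a COCO image entry with width/height scaled and rounded.

    Rounding the image size while scaling coordinates exactly introduces
    at most sub-pixel error.
    """
    scaled = dict(img)
    scaled["width"] = round(img["width"] * scale)
    scaled["height"] = round(img["height"] * scale)
    return scaled
```

To convert a whole file, one would apply scale_coco_image to each entry in coco["images"] and scale_coco_annotation to every annotation whose image_id matches, using that image's factor.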

I am also getting this error: AssertionError: A prediction has class=2, but the dataset only has 2 classes and predicted class id should be in [0, 1].

However, when I set NUM_CLASSES to 2 for training, I get CUDA error: device-side assert triggered.
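A device-side assert during training is often caused by a target label that is greater than or equal to the number of classes the head outputs. Below is a hypothetical pre-flight check on the COCO JSON, not part of SparseInst or detectron2. It assumes category ids are remapped to contiguous ids 0..K-1 in sorted id order, which is my understanding of how detectron2's COCO loader builds its mapping.

```python
from typing import Any, Dict, List


def contiguous_id_map(coco: Dict[str, Any]) -> Dict[int, int]:
    """Map COCO category ids to contiguous ids 0..K-1 in sorted id order."""
    cat_ids = sorted(cat["id"] for cat in coco["categories"])
    return {cat_id: idx for idx, cat_id in enumerate(cat_ids)}


def find_label_problems(coco: Dict[str, Any], num_classes: int) -> List[str]:
    """List reasons why annotations would not fit a head with `num_classes` outputs."""
    problems = []
    id_map = contiguous_id_map(coco)
    if len(id_map) != num_classes:
        problems.append(
            f"JSON defines {len(id_map)} categories but NUM_CLASSES={num_classes}"
        )
    for ann in coco["annotations"]:
        cat_id = ann["category_id"]
        if cat_id not in id_map:
            problems.append(f"annotation {ann['id']} uses undeclared category_id {cat_id}")
        elif id_map[cat_id] >= num_classes:
            problems.append(
                f"annotation {ann['id']} maps to class {id_map[cat_id]} >= {num_classes}"
            )
    return problems
```

Running it (after json.load) on both the training and the evaluation JSON would show whether the two files declare different category lists. That is one possible way to end up with a model that predicts class 2 while the evaluator only knows two classes, but it is a guess, not a confirmed diagnosis.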