pytorch-labs / segment-anything-fast

A batched, offline-inference-oriented version of segment-anything
Apache License 2.0

Using sam_model_fast_registry on NVIDIA GeForce RTX 4090 gives an error #95

Open Aatroy opened 11 months ago

Aatroy commented 11 months ago
/opt/conda/lib/python3.8/site-packages/torch/_dynamo/utils.py:1570: UserWarning: Memory Efficient Attention requires the attn_mask to be aligned to 8 elements. Prior to calling SDPA, pad the last dimension of the attn_mask to be at least a multiple of 8 and then slice the attn_mask to the original size. (Triggered internally at ../aten/src/ATen/native/transformers/attention.cpp:551.)
  return node.target(*args, **kwargs)
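The UserWarning above spells out its own workaround: pad the last dimension of the attn_mask to a multiple of 8 before calling SDPA, then slice it back to the original size so the underlying storage stays aligned. A minimal sketch of that recipe, with hypothetical shapes (a last dimension of 14, matching the bmm shapes in the log below) and assuming a CUDA device:

    import torch
    import torch.nn.functional as F

    # hypothetical query/key/value and additive attention mask
    q = k = v = torch.randn(1, 8, 14, 80, device="cuda", dtype=torch.bfloat16)
    mask = torch.randn(1, 8, 14, 14, device="cuda", dtype=torch.bfloat16)

    L = mask.shape[-1]
    pad = (8 - L % 8) % 8                    # elements needed to reach a multiple of 8
    mask = F.pad(mask, (0, pad))[..., :L]    # pad last dim, then slice back to size
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)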
AUTOTUNE convolution(1x3x1024x1024, 1280x3x16x16)
  convolution 0.2431 ms 100.0%
  triton_convolution_3 0.5962 ms 40.8%
  triton_convolution_1 0.6139 ms 39.6%
  triton_convolution_6 0.6553 ms 37.1%
  triton_convolution_4 0.8272 ms 29.4%
  triton_convolution_5 0.9010 ms 27.0%
  triton_convolution_0 1.0113 ms 24.0%
  triton_convolution_2 3.7757 ms 6.4%
SingleProcess AUTOTUNE takes 8.4253 seconds
AUTOTUNE mm(4900x1280, 1280x3840)
  triton_mm_10 0.3042 ms 100.0%
  triton_mm_9 0.3044 ms 100.0%
  triton_mm_15 0.3175 ms 95.8%
  mm 0.3243 ms 93.8%
  triton_mm_8 0.3281 ms 92.7%
  triton_mm_11 0.3322 ms 91.6%
  triton_mm_7 0.3356 ms 90.6%
  triton_mm_14 0.3555 ms 85.6%
  triton_mm_13 0.4628 ms 65.7%
  triton_mm_12 0.4674 ms 65.1%
SingleProcess AUTOTUNE takes 9.9830 seconds
AUTOTUNE bmm(14x5600x80, 14x80x14)
  bmm 0.0276 ms 100.0%
  triton_bmm_26 0.0531 ms 51.9%
  triton_bmm_21 0.0539 ms 51.1%
  triton_bmm_20 0.0546 ms 50.5%
  triton_bmm_23 0.0552 ms 49.9%
  triton_bmm_24 0.0559 ms 49.3%
  triton_bmm_30 0.0561 ms 49.1%
  triton_bmm_28 0.0570 ms 48.3%
  triton_bmm_29 0.0576 ms 47.9%
  triton_bmm_22 0.0577 ms 47.8%
SingleProcess AUTOTUNE takes 9.0480 seconds
AUTOTUNE bmm(14x5600x80, 14x80x14)
  bmm 0.0295 ms 100.0%
  triton_bmm_36 0.0546 ms 54.1%
  triton_bmm_39 0.0550 ms 53.7%
  triton_bmm_35 0.0558 ms 52.9%
  triton_bmm_33 0.0558 ms 52.9%
  triton_bmm_32 0.0559 ms 52.8%
  triton_bmm_41 0.0563 ms 52.4%
  triton_bmm_31 0.0563 ms 52.4%
  triton_bmm_42 0.0564 ms 52.3%
  triton_bmm_40 0.0567 ms 52.1%
SingleProcess AUTOTUNE takes 6.5209 seconds
AUTOTUNE mm(4900x1280, 1280x1280)
  triton_mm_44 0.1199 ms 100.0%
  triton_mm_51 0.1242 ms 96.5%
  triton_mm_46 0.1272 ms 94.2%
  triton_mm_47 0.1283 ms 93.4%
  triton_mm_45 0.1284 ms 93.4%
  triton_mm_43 0.1305 ms 91.9%
  mm 0.1467 ms 81.8%
  triton_mm_50 0.1582 ms 75.8%
  triton_mm_49 0.1715 ms 69.9%
  triton_mm_48 0.1749 ms 68.5%
SingleProcess AUTOTUNE takes 7.7492 seconds
AUTOTUNE mm(4096x1280, 1280x5120)
  triton_mm_56 0.3222 ms 100.0%
  triton_mm_58 0.3230 ms 99.8%
  triton_mm_57 0.3249 ms 99.2%
  triton_mm_59 0.3261 ms 98.8%
  mm 0.3281 ms 98.2%
  triton_mm_63 0.3495 ms 92.2%
  triton_mm_55 0.3971 ms 81.1%
  triton_mm_62 0.4037 ms 79.8%
  triton_mm_61 0.5124 ms 62.9%
  triton_mm_60 0.5172 ms 62.3%
SingleProcess AUTOTUNE takes 10.1354 seconds
AUTOTUNE mm(4096x5120, 5120x1280)
  triton_mm_70 0.3214 ms 100.0%
  triton_mm_71 0.3223 ms 99.7%
  mm 0.3253 ms 98.8%
  triton_mm_69 0.3438 ms 93.5%
  triton_mm_68 0.3445 ms 93.3%
  triton_mm_75 0.3832 ms 83.9%
  triton_mm_67 0.3880 ms 82.8%
  triton_mm_74 0.4686 ms 68.6%
  triton_mm_73 0.5219 ms 61.6%
  triton_mm_72 0.5222 ms 61.5%
SingleProcess AUTOTUNE takes 9.7927 seconds
AUTOTUNE addmm(4096x3840, 4096x1280, 1280x3840)
  triton_mm_514 0.2487 ms 100.0%
  triton_mm_513 0.2495 ms 99.7%
  triton_mm_515 0.2504 ms 99.3%
  triton_mm_512 0.2662 ms 93.4%
  triton_mm_519 0.2712 ms 91.7%
  bias_addmm 0.2847 ms 87.4%
  triton_mm_518 0.2966 ms 83.9%
  triton_mm_511 0.2989 ms 83.2%
  addmm 0.3052 ms 81.5%
  triton_mm_516 0.3908 ms 63.7%
SingleProcess AUTOTUNE takes 9.8966 seconds
AUTOTUNE bmm(64x1024x80, 64x80x64)
  triton_bmm_531 0.0304 ms 100.0%
  triton_bmm_532 0.0323 ms 94.2%
  triton_bmm_529 0.0324 ms 93.7%
  triton_bmm_528 0.0327 ms 93.0%
  bmm 0.0328 ms 92.7%
  triton_bmm_524 0.0329 ms 92.4%
  triton_bmm_523 0.0332 ms 91.7%
  triton_bmm_525 0.0335 ms 90.8%
  triton_bmm_530 0.0336 ms 90.5%
  triton_bmm_526 0.0347 ms 87.6%
SingleProcess AUTOTUNE takes 7.9944 seconds
AUTOTUNE mm(4096x1280, 1280x1280)
  triton_mm_550 0.0950 ms 100.0%
  triton_mm_551 0.0951 ms 99.8%
  triton_mm_549 0.0982 ms 96.7%
  triton_mm_548 0.0994 ms 95.6%
  triton_mm_555 0.1012 ms 93.9%
  triton_mm_547 0.1047 ms 90.7%
  mm 0.1159 ms 81.9%
  triton_mm_554 0.1332 ms 71.3%
  triton_mm_553 0.1449 ms 65.6%
  triton_mm_552 0.1482 ms 64.1%
SingleProcess AUTOTUNE takes 9.5205 seconds
AUTOTUNE convolution(1x1280x64x64, 256x1280x1x1)
  convolution 0.0409 ms 100.0%
  triton_convolution_2315 0.0632 ms 64.7%
  triton_convolution_2317 0.0963 ms 42.5%
  triton_convolution_2314 0.0977 ms 41.8%
  triton_convolution_2312 0.1126 ms 36.3%
  conv1x1_via_mm 0.1306 ms 31.3%
  triton_convolution_2316 0.1394 ms 29.3%
  triton_convolution_2313 0.1460 ms 28.0%
  triton_convolution_2311 0.1762 ms 23.2%
SingleProcess AUTOTUNE takes 6.8541 seconds
AUTOTUNE convolution(1x256x64x64, 256x256x3x3)
  convolution 0.0861 ms 100.0%
  triton_convolution_2324 0.1452 ms 59.3%
  triton_convolution_2319 0.1805 ms 47.7%
  triton_convolution_2322 0.1813 ms 47.5%
  triton_convolution_2321 0.1835 ms 46.9%
  triton_convolution_2323 0.2455 ms 35.1%
  triton_convolution_2320 0.2628 ms 32.8%
  triton_convolution_2318 0.3300 ms 26.1%
SingleProcess AUTOTUNE takes 8.1397 seconds
Traceback (most recent call last):
  File "server/main.py", line 36, in sam
    ans = pth_process(info, image)
  File "/sam/./scripts/pth_model.py", line 143, in process
    masks, scores, logits = predictor.infer(image,input_point,input_label,input_box,pmask)
  File "/sam/./scripts/pth_model.py", line 20, in infer
    return self.predictor.infer(input_image, input_point, input_label,input_box,input_mask)
  File "/sam/./sam_fast/segment_anything_fast/predictor.py", line 282, in infer
    return self.predict(
  File "/sam/./sam_fast/segment_anything_fast/predictor.py", line 165, in predict
    iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
TypeError: Got unsupported ScalarType BFloat16

I use vit_h with the code below; sam_fast is the main branch of segment-anything-fast:

    from sam_fast.segment_anything_fast import sam_model_registry, sam_model_fast_registry, SamPredictor

    sam = sam_model_fast_registry[model_type](checkpoint=checkpoint)
    sam.to(device=config["device"])
    predictor = SamPredictor(sam)
    masks, scores, logits = predictor.infer(image, input_point, input_label, input_box, pmask)

predictor.infer:

    def infer(self, input_image, input_point, input_label, input_box, input_mask):
        self.set_image(input_image)
        return self.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box,
            multimask_output=False,
            mask_input=input_mask,
        )
Aatroy commented 11 months ago

Perhaps an inconsistency in the environment causes the error. Could you please provide a Dockerfile?

mawanda-jun commented 11 months ago

Hi, I had the same problem. It is caused by the conversion that does not happen automatically when you cast a bfloat16 torch tensor to a float32 numpy array: numpy has no bfloat16 dtype. You should replace this line:

    iou_predictions_np = iou_predictions[0].detach().cpu().numpy()

with:

    iou_predictions_np = iou_predictions[0].detach().cpu().float().numpy()

so that you force torch to explicitly convert the bfloat16 tensor to a float32 one.
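A minimal self-contained reproduction of the failure and the fix; the tensor here is just a stand-in for iou_predictions:

    import torch

    iou_predictions = torch.rand(1, 3, dtype=torch.bfloat16)

    # numpy has no bfloat16 dtype, so this raises
    # "TypeError: Got unsupported ScalarType BFloat16":
    # iou_predictions[0].detach().cpu().numpy()

    # Upcasting to float32 first succeeds:
    iou_np = iou_predictions[0].detach().cpu().float().numpy()
    print(iou_np.dtype)  # float32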

Be aware not to use .half() instead of .float(), since you would risk under/overflow: the representation limits of bfloat16 are the same as those of float32 (much more range than fp16, but less precision), so values that fit in bfloat16 can overflow in fp16.
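To see why .float() is the safe choice, compare the representable ranges reported by torch.finfo:

    import torch

    # bfloat16 keeps float32's 8 exponent bits, so its dynamic range matches
    # float32; float16 has only 5 exponent bits and overflows far sooner.
    print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same order as float32
    print(torch.finfo(torch.float16).max)   # 65504.0

    x = torch.tensor([1e30], dtype=torch.bfloat16)
    print(x.half())   # tensor([inf], dtype=torch.float16) -- overflow
    print(x.float())  # finite, ~1e30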

EDIT: Refer to this pull request I just made

cpuhrsch commented 11 months ago

I just merged @mawanda-jun's pull request. Please try again and let me know if it doesn't work :) Thank you for your interest in the project!