SysCV / sam-hq

Segment Anything in High Quality [NeurIPS 2023]
https://arxiv.org/abs/2306.01567
Apache License 2.0

Why does my implementation of sam-hq light take so long? #64

Open mickyzyf opened 1 year ago

mickyzyf commented 1 year ago

```python
import time

import cv2
import matplotlib.pyplot as plt
import torch

from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# show_anns is the mask-overlay helper from SAM's demo notebook
fig, axs = plt.subplots(1, 3)

image = cv2.imread('demo/input_imgs/71fb2a7c-a0cd-4908-b2b6-672e68ca583a.webp')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
axs[0].imshow(image)
plt.axis('off')

sam_checkpoint = "pretrained/sam_hq_vit_tiny.pth"
model_type = "vit_tiny"
device = "cuda:5"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(model=sam)

start = time.time()
masks = mask_generator.generate(image)
print('mask1: {}'.format(time.time() - start))
print(f'len of mask {len(masks)}')

axs[1].imshow(image)
show_anns(masks, axs[1])
plt.axis('off')

torch.cuda.empty_cache()

sam_checkpoint = "/mnt/model/pretrained/sam_hq_vit_h.pth"
model_type = "vit_h"
device = "cuda:7"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

mask_generator2 = SamAutomaticMaskGenerator(model=sam)

start = time.time()
masks2 = mask_generator2.generate(image, multimask_output=False)
print('mask2: {}'.format(time.time() - start))
print(f'len of mask {len(masks2)}')

axs[2].imshow(image)
show_anns(masks2, axs[2])
plt.axis('off')

fig.suptitle("3x1 subplot demo")
plt.tight_layout()
plt.savefig("demo/outputs_imgs/subplot_demo.png")
```

The output is:

```
mask1: 5.129418849945068
len of mask 14
mask2: 5.125577211380005
len of mask 14
```

I would like to get all segmentations from an image using sam-hq, and I expected the light model to be much faster. However, it takes about the same time as sam_hq_vit_h on the few images I tested. Does anyone have a clue why?
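As an aside, single-shot timings like the ones above can be misleading on GPU: the first call pays one-off costs such as CUDA context creation and kernel warm-up, which can mask the difference between backbones. Below is a small framework-agnostic helper (a sketch of my own, not part of sam-hq) that warms up first and then averages over several runs:

```python
import time

def benchmark(fn, warmup=1, runs=3):
    """Average wall-clock time of fn() after warm-up calls.

    Warm-up matters on GPU: the first call pays one-off setup costs
    that would otherwise be attributed to whichever model runs first.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    return (time.perf_counter() - start) / runs
```

With the snippet above, one could call e.g. `benchmark(lambda: mask_generator.generate(image))` for each model and compare the averaged numbers instead of a single run.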
lkeab commented 1 year ago

Hi, can you remove the post-processing steps (used in SAM's anything mode) during inference and report the time for both the vit-h and vit-tiny based HQ-SAM again? Thanks.
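One way to act on this suggestion is to time each pipeline stage separately instead of only the whole `generate()` call, so backbone time and post-processing time can be compared directly. A minimal sketch of per-stage timing; the stage labels and `time.sleep` bodies below are placeholders, not sam-hq's actual internals:

```python
import time
from contextlib import contextmanager

@contextmanager
def stage_timer(label, results):
    """Accumulate wall-clock time per pipeline stage into `results`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = results.get(label, 0.0) + time.perf_counter() - start

# Usage with placeholder stages standing in for the real pipeline:
timings = {}
with stage_timer("image_encoder", timings):
    time.sleep(0.01)  # stand-in for the ViT forward pass
with stage_timer("post_processing", timings):
    time.sleep(0.01)  # stand-in for anything-mode post-processing
print(timings)
```

When timing CUDA code this way, call `torch.cuda.synchronize()` immediately before each clock read, since GPU kernels launch asynchronously and the Python timer would otherwise stop before the work finishes.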