facebookresearch / segment-anything

The repository provides code for running inference with the Segment Anything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0

computational complexity #630

Open yejing0 opened 10 months ago

yejing0 commented 10 months ago

Very excellent work! We would like to further evaluate the efficiency of the model. How do we measure its FLOPs?

heyoeyo commented 10 months ago

I'm not sure if there is a perfect way to count FLOPs. One option is to use software that estimates the count from the model definition. Searching around, it looks like facebookresearch has another repo (fvcore) that provides this functionality, which you may want to try.
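As a rough, untested sketch (assuming fvcore is installed alongside this repo, and using random weights since the values don't affect the op count), fvcore's FlopCountAnalysis could be pointed at the image encoder like this:

import torch
from fvcore.nn import FlopCountAnalysis
from segment_anything import sam_model_registry

# Build the ViT-B variant; no checkpoint is needed just to count ops.
sam = sam_model_registry["vit_b"](checkpoint=None)
sam.eval()

# The image encoder expects a fixed 1024x1024 RGB input.
dummy_image = torch.randn(1, 3, 1024, 1024)

with torch.no_grad():
    flops = FlopCountAnalysis(sam.image_encoder, dummy_image)
    # fvcore counts one fused multiply-add as a single "flop" and skips
    # (and warns about) any ops it doesn't have a handler for.
    print(f"Image encoder: {flops.total() / 1e9:.1f} GFLOPs (fvcore convention)")

The prompt encoder and mask decoder could be analyzed the same way, though their cost is small compared to the image encoder.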

The other option is to manually estimate the count from the model definition, which is easier said than done! That being said, the bulk of the SAM model is transformer-based, so if you estimate the FLOP count for a single transformer block of the image encoder and the mask decoder, you could probably get a good approximation of the total FLOP count just by multiplying by the number of transformer blocks in each component of the model (the image encoder and mask decoder sizing configs can both be found in the build_sam.py script).
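To make that concrete, here is a rough back-of-the-envelope tally (my own approximation, not an official figure) of the dominant matrix multiplies in one full-attention transformer block. Note that most blocks in SAM's ViT image encoder use 14x14 windowed attention, so treating every block as global attention overestimates the attention term:

def transformer_block_flops(num_tokens, dim, mlp_ratio=4):
    # Dominant matmuls only: norms, biases, softmax and relative
    # position terms are ignored.
    qkv = 3 * num_tokens * dim * dim                # Q, K, V projections
    attn = 2 * num_tokens * num_tokens * dim        # Q @ K^T and scores @ V
    proj = num_tokens * dim * dim                   # attention output projection
    mlp = 2 * num_tokens * dim * (mlp_ratio * dim)  # two-layer MLP
    return 2 * (qkv + attn + proj + mlp)            # 2 FLOPs per multiply-add

# ViT-B image encoder: 1024 / 16 = 64x64 patch tokens, embed dim 768, 12 blocks
num_tokens, dim, depth = 64 * 64, 768, 12
total = depth * transformer_block_flops(num_tokens, dim)
print(f"~{total / 1e9:.0f} GFLOPs for the blocks alone (global-attention upper bound)")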

BeliverK commented 3 months ago

Very excellent work! I would like to ask: have you tried using fvcore to calculate FLOPs, and which parts of the model did you include in the calculation? My own attempt did not succeed, so could you share your calculation code?

BeliverK commented 3 months ago

I used calflops to test SAM-Base and SAM-Large; the code I used is shown below. I didn't test the FLOPs of SAM-H due to memory limitations. I now want to measure FLOPs for MobileSAM as well, but that hasn't worked out yet.

from calflops import calculate_flops
from PIL import Image
from transformers import SamModel, SamProcessor
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamModel.from_pretrained("./sam-vitb")
processor = SamProcessor.from_pretrained("./sam-vitb")

img_pth = "./car.png"
raw_image = Image.open(img_pth).convert("RGB")
# One box prompt in (x1, y1, x2, y2) format
input_boxes = [[[450, 600, 2000, 1000]]]
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt")

outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)

flops, macs, params = calculate_flops(model=model, kwargs=inputs, print_results=True)
print("SAM FLOPs:%s  MACs:%s  Params:%s\n" % (flops, macs, params))