CASIA-IVA-Lab / FastSAM

Fast Segment Anything
GNU Affero General Public License v3.0

Output from onnx format #53

Closed okdha1234 closed 1 year ago

okdha1234 commented 1 year ago

Hi, the output from the ONNX format is of shape [(1, 37, 21504), (1, 32, 256, 256)]. I post-process it using the method below with conf = 0.4, iou = 0.9, and agnostic_nms = False, like in the FastSAM .pt model, but it doesn't return masks of the same length.

Can someone explain the outputs from the ONNX-format FastSAM model and how to post-process them?

def postprocess(preds, conf, iou, agnostic_nms=False):
    """TODO: filter by classes."""
    p = ops.non_max_suppression(preds[0],
                                conf,
                                iou,
                                agnostic_nms,
                                max_det=100,
                                nc=1)

    results = []
    proto = preds[1]  # second output is len 3 if pt, but only 1 if exported
    for i, pred in enumerate(p):
        pred[:, :4] = ops.scale_boxes(torch.Size([1024, 1024]), pred[:, :4], (1024, 1024))
        masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], (1024, 1024))  # HWC
        return masks
berry-ding commented 1 year ago

Hi @okdha1234, YOLOv8 already supports inference with ONNX models. You can refer to the code below and simply replace FastSAM.pt with FastSAM.onnx.

from fastsam import FastSAM, FastSAMPrompt

model = FastSAM('./weights/FastSAM.pt')
IMAGE_PATH = './images/dogs.jpg'
DEVICE = 'cpu'
everything_results = model(IMAGE_PATH, device=DEVICE, retina_masks=True, imgsz=1024, conf=0.4, iou=0.9,)
prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE)

# everything prompt
ann = prompt_process.everything_prompt()

# bbox default shape [0,0,0,0] -> [x1,y1,x2,y2]
ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300])

# text prompt
ann = prompt_process.text_prompt(text='a photo of a dog')

# point prompt
# points default [[0,0]] [[x1,y1],[x2,y2]]
# point_label default [0] [1,0] 0:background, 1:foreground
ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1])

prompt_process.plot(annotations=ann, output='./output/')

If you wish to customize ONNX inference, please refer to fastsam/predict.py for post-processing.
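For reference, in the ONNX outputs the (1, 37, 21504) tensor carries 4 box coordinates, 1 object score, and 32 mask coefficients for each of the 21504 candidate boxes, and the (1, 32, 256, 256) tensor holds the mask prototypes. Below is a minimal sketch of how the two could be combined, assuming a 1024x1024 input, ONNX Runtime outputs as numpy arrays, and the ultralytics ops helpers already used in fastsam/predict.py; the names postprocess_onnx, pred_onnx, and proto_onnx are placeholders, not FastSAM API.

import torch
from ultralytics.yolo.utils import ops  # module path may differ in newer ultralytics releases

def postprocess_onnx(pred_onnx, proto_onnx, conf=0.4, iou=0.9, agnostic_nms=False):
    # pred_onnx:  (1, 37, 21504) numpy array -> 4 box coords + 1 score + 32 mask coefficients
    # proto_onnx: (1, 32, 256, 256) numpy array -> mask prototypes at 1/4 of the input size
    preds = torch.from_numpy(pred_onnx)
    protos = torch.from_numpy(proto_onnx)

    # NMS keeps rows of (x1, y1, x2, y2, conf, cls, 32 mask coefficients); nc=1 for FastSAM
    p = ops.non_max_suppression(preds, conf, iou, agnostic=agnostic_nms, max_det=100, nc=1)

    boxes, masks = [], []
    for i, pred in enumerate(p):
        if not len(pred):
            boxes.append(pred)
            masks.append(None)
            continue
        # columns 6+ are the mask coefficients; prototypes are upsampled to the
        # 1024x1024 input size and cropped to each predicted box
        m = ops.process_mask_native(protos[i], pred[:, 6:], pred[:, :4], (1024, 1024))
        boxes.append(pred[:, :6])
        masks.append(m)
    return boxes, masks

If the original image is not 1024x1024, the boxes would still need ops.scale_boxes back to the original image shape, similar to the post-processing in fastsam/predict.py.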

okdha1234 commented 1 year ago

@berry-ding thanks.