facebookresearch / segment-anything

The repository provides code for running inference with the SegmentAnything Model (SAM), links for downloading the trained model checkpoints, and example notebooks that show how to use the model.
Apache License 2.0
47.81k stars 5.65k forks source link

How to make SAM (vit_l ONNX) faster? #766

Open Siwakonrome opened 5 months ago

Siwakonrome commented 5 months ago

My Configuration

  1. Windows OS
  2. vit_l onnx
  3. RTX 2080

My processing time is 10.1 seconds per image. I want it to be faster, under 1.0 to 2.0 seconds, so that it is suitable for my application.

def __call__(self, image_np, box):
    """Segment the object inside ``box`` and report how long it took.

    Args:
        image_np: Image array forwarded to ``self.preprocess``.
        box: Box description forwarded to ``self.create_box_as_a_prompt``.

    Returns:
        ``(contours, hierarchy, process_time)`` where the first two values
        come from ``self.postprocess`` and ``process_time`` is the elapsed
        time in seconds.
    """
    # perf_counter() is monotonic and high-resolution, so the measured
    # interval cannot be skewed by wall-clock adjustments.
    started = time.perf_counter()
    prompt = self.create_box_as_a_prompt(box=box)
    input_tensor, resized, orig = self.preprocess(image_np=image_np)
    contours, hierarchy = self.postprocess(
        input_tensor=input_tensor,
        box=prompt,
        resized=resized,
        orig=orig,
    )
    process_time = time.perf_counter() - started
    return contours, hierarchy, process_time

My code.

import cv2
import time
import numpy as np
from PIL import Image
import onnxruntime as ort
from copy import deepcopy

class SegmentSamOnnxOperator:
    """Box-prompted SAM segmentation backed by two ONNX Runtime sessions.

    The encoder session turns a preprocessed image into embeddings; the
    decoder session turns embeddings plus a box prompt into a mask, which is
    then converted to OpenCV contours.
    """

    # Side length of the square tensor fed to the encoder.
    INPUT_SIZE = 1024
    # Per-channel (RGB) normalisation constants applied before encoding.
    PIXEL_MEAN = (123.675, 116.28, 103.53)
    PIXEL_STD = (58.395, 57.12, 57.375)
    # Execution providers tried, in order, when the caller does not choose.
    PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self, onnx_encoder_path, onnx_decoder_path, providers=None):
        """Create the encoder and decoder inference sessions.

        Args:
            onnx_encoder_path: Path of the encoder ONNX model.
            onnx_decoder_path: Path of the decoder ONNX model.
            providers: Optional execution-provider list passed to
                ``ort.InferenceSession``. When omitted, CUDA is preferred if
                the installed onnxruntime build reports it, with CPU as the
                fallback.
        """
        # NOTE(review): without an explicit provider list, GPU builds of
        # ONNX Runtime do not reliably pick CUDA (depending on the version
        # they raise or fall back to CPU) -- the usual cause of multi-second
        # encoder runs on a machine that has a GPU.
        if providers is None:
            providers = self._choose_providers(ort.get_available_providers())
        self.providers = list(providers)
        self.encoder = ort.InferenceSession(
            onnx_encoder_path, providers=self.providers
        )
        self.decoder = ort.InferenceSession(
            onnx_decoder_path, providers=self.providers
        )

    def __call__(self, image_np, box):
        """Segment the object inside ``box`` and time the whole pipeline.

        Args:
            image_np: BGR image array (it is converted with
                ``cv2.COLOR_BGR2RGB`` in ``preprocess``).
            box: Mapping with ``'x'``, ``'y'``, ``'width'`` and ``'height'``.

        Returns:
            ``(contours, hierarchy, process_time)``; ``process_time`` is the
            elapsed time in seconds.
        """
        # perf_counter() is monotonic, unlike the wall clock time.time().
        started = time.perf_counter()
        prompt = self.create_box_as_a_prompt(box=box)
        input_tensor, resized, orig = self.preprocess(image_np=image_np)
        contours, hierarchy = self.postprocess(
            input_tensor=input_tensor,
            box=prompt,
            resized=resized,
            orig=orig,
        )
        process_time = time.perf_counter() - started
        return contours, hierarchy, process_time

    def create_box_as_a_prompt(self, box):
        """Convert an x/y/width/height mapping to ``[x0, y0, x1, y1]``."""
        left, top = box['x'], box['y']
        return np.array([left, top, left + box['width'], top + box['height']])

    def postprocess(self, input_tensor, box, resized, orig):
        """Run encoder and decoder, then extract contours from the mask.

        Args:
            input_tensor: Encoder input as returned by ``preprocess``.
            box: ``[x0, y0, x1, y1]`` array in original-image pixels.
            resized: ``(width, height)`` of the resized image.
            orig: ``(width, height)`` of the original image.

        Returns:
            ``(contours, hierarchy)`` as returned by ``cv2.findContours``.
        """
        orig_width, orig_height = orig
        embeddings = self.encoder.run(None, {"images": input_tensor})[0]
        onnx_coord, onnx_label = self._encode_box_prompt(box, resized, orig)
        # No previous mask is supplied, so pass zeros and flag it as absent.
        masks, _, _ = self.decoder.run(None, {
            "image_embeddings": embeddings,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
            "has_mask_input": np.zeros(1, dtype=np.float32),
            "orig_im_size": np.array(
                [orig_height, orig_width], dtype=np.float32
            ),
        })
        # Binarise the first mask to 0/255. It is already a single-channel
        # uint8 image, so it can go to findContours directly; a PIL/RGB/gray/
        # threshold round trip would only reproduce the same pixels.
        mask = (masks[0][0] > 0).astype(np.uint8) * 255
        contours, hierarchy = cv2.findContours(
            mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE
        )
        return contours, hierarchy

    def preprocess(self, image_np):
        """Resize, normalise and zero-pad an image for the encoder.

        Args:
            image_np: BGR image array.

        Returns:
            ``(input_tensor, (resized_width, resized_height),
            (orig_width, orig_height))`` where ``input_tensor`` is a float32
            array of shape ``(1, C, INPUT_SIZE, INPUT_SIZE)``.
        """
        img = Image.fromarray(cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB))
        orig_width, orig_height = img.size
        resized_width, resized_height = self._resized_size(
            orig_width, orig_height
        )
        img = img.resize(
            (resized_width, resized_height), Image.Resampling.BILINEAR
        )
        # Normalise directly in float32 (the dtype the tensor is returned
        # in) instead of computing in float64 and casting afterwards.
        mean = np.array(self.PIXEL_MEAN, dtype=np.float32)
        std = np.array(self.PIXEL_STD, dtype=np.float32)
        pixels = (np.asarray(img, dtype=np.float32) - mean) / std
        # (H, W, C) -> (1, C, H, W)
        input_tensor = pixels.transpose(2, 0, 1)[None]
        # Zero-pad bottom/right up to the square input size; the long side
        # already equals INPUT_SIZE, so its pad width is 0.
        side = self.INPUT_SIZE
        input_tensor = np.pad(
            input_tensor,
            (
                (0, 0),
                (0, 0),
                (0, side - resized_height),
                (0, side - resized_width),
            ),
        )
        return (
            input_tensor,
            (resized_width, resized_height),
            (orig_width, orig_height),
        )

    @classmethod
    def _choose_providers(cls, available):
        """Return the preferred providers present in ``available``, in order."""
        chosen = [name for name in cls.PREFERRED_PROVIDERS if name in available]
        return chosen or list(available)

    @classmethod
    def _resized_size(cls, orig_width, orig_height):
        """Return ``(width, height)`` with the long side scaled to INPUT_SIZE.

        Integer arithmetic keeps the result exact (no float truncation), and
        the short side is clamped to at least one pixel.
        """
        side = cls.INPUT_SIZE
        if orig_width > orig_height:
            return side, max(1, side * orig_height // orig_width)
        return max(1, side * orig_width // orig_height), side

    @staticmethod
    def _encode_box_prompt(box, resized, orig):
        """Build the decoder's ``point_coords`` / ``point_labels`` for a box.

        The two box corners are rescaled from original-image pixels to
        resized-image pixels. Returns float32 arrays of shape ``(1, 2, 2)``
        and ``(1, 2)``.
        """
        orig_width, orig_height = orig
        resized_width, resized_height = resized
        # astype() already copies, so the caller's box is never mutated.
        coords = np.asarray(box).reshape(1, 2, 2).astype(np.float64)
        coords[..., 0] *= resized_width / orig_width
        coords[..., 1] *= resized_height / orig_height
        # NOTE(review): labels 2 and 3 appear to tag the top-left and
        # bottom-right box corners per SAM's ONNX prompt convention --
        # confirm against the SAM ONNX example.
        labels = np.array([[2, 3]], dtype=np.float32)
        return coords.astype(np.float32), labels