z-x-yang / Segment-and-Track-Anything

An open-source project dedicated to tracking and segmenting any objects in videos, either automatically or interactively. The primary algorithms utilized include the Segment Anything Model (SAM) for key-frame segmentation and Associating Objects with Transformers (AOT) for efficient tracking and propagation purposes.
GNU Affero General Public License v3.0

Questions about running camera in real time #74

Closed · 525753936 closed this issue 1 year ago

525753936 commented 1 year ago

Hello,

I hope this message finds you well. Thanks for sharing your amazing work! I have some questions about your code. I am not familiar with the principles of segmentation and just want to use it for my own application. I noticed a parameter called `sam_gap` in the `demo.ipynb` file, which seemingly controls how often a new `pred_mask` is generated. What confused me is that when I ran my modified code on a camera in real time, it seemed to get stuck for a few seconds after every 4 processed frames. Can I just set this number higher, or should I keep it the same? Is there any technique to run it in real time? You can check the modified code here: /home/zhengchen/Segment-and-Track-Anything-main/test2.py. Thanks for your marvelous work again!

525753936 commented 1 year ago

My bad, here is the code.

```python
import os
import gc

import cv2
import numpy as np
import rospy
import torch
from cv_bridge import CvBridge, CvBridgeError
from PIL import Image
from scipy.ndimage import binary_dilation
from sensor_msgs.msg import Image as ROSImage

from aot_tracker import _palette
from model_args import aot_args, sam_args, segtracker_args
from SegTracker import SegTracker

# Add ROS-related variables and initialization
bridge = CvBridge()
image_received = False
cv_image = None

def image_callback(msg):
    # Convert the incoming ROS image message to an OpenCV BGR frame.
    global image_received, cv_image
    try:
        cv_image = bridge.imgmsg_to_cv2(msg, "bgr8")
        image_received = True
    except CvBridgeError as e:
        print(e)

# Initialize ROS node
rospy.init_node('zed2_image_listener', anonymous=True)

# Create a subscriber
image_sub = rospy.Subscriber("/camera/rgb/image_raw", ROSImage, image_callback)

# Segmentation-related functions remain the same...

def save_prediction(pred_mask, output_dir, file_name):
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask.save(os.path.join(output_dir, file_name))

def colorize_mask(pred_mask):
    save_mask = Image.fromarray(pred_mask.astype(np.uint8))
    save_mask = save_mask.convert(mode='P')
    save_mask.putpalette(_palette)
    save_mask = save_mask.convert(mode='RGB')
    return np.array(save_mask)

def draw_mask(img, mask, alpha=0.5, id_countour=False):
    img_mask = img.copy()  # copy so the input frame is not modified in place
    if id_countour:
        # very slow ~ 1s per image
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        for id in obj_ids:
            # Overlay color on binary mask
            if id <= 255:
                color = _palette[id * 3:id * 3 + 3]
            else:
                color = [0, 0, 0]
            foreground = img * (1 - alpha) + np.ones_like(img) * alpha * np.array(color)
            binary_mask = (mask == id)

            # Compose image
            img_mask[binary_mask] = foreground[binary_mask]

            countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[countours, :] = 0
    else:
        binary_mask = (mask != 0)
        countours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[countours, :] = 0

    return img_mask.astype(img.dtype)

segtracker_args = {
    'sam_gap': 5,            # run SAM every sam_gap frames to detect new objects
    'min_area': 200,         # minimum mask area for a new object
    'max_obj_num': 255,      # maximum number of tracked objects
    'min_new_obj_iou': 0.8,
}

frame_idx = 0
segtracker = SegTracker(segtracker_args, sam_args, aot_args)
segtracker.restart_tracker()

torch.cuda.empty_cache()
gc.collect()
sam_gap = segtracker_args['sam_gap']

with torch.cuda.amp.autocast():
    while not rospy.is_shutdown():
        if image_received:
            image_received = False  # consume the frame so it is processed only once
            frame = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
            if frame_idx == 0:
                # First frame: segment everything with SAM and initialize the tracker.
                pred_mask = segtracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                segtracker.add_reference(frame, pred_mask)
            elif (frame_idx % sam_gap) == 0:
                # Every sam_gap frames: re-run SAM and merge newly found objects.
                seg_mask = segtracker.seg(frame)
                torch.cuda.empty_cache()
                gc.collect()
                track_mask = segtracker.track(frame)
                new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)
                pred_mask = track_mask + new_obj_mask
                segtracker.add_reference(frame, pred_mask)
            else:
                # All other frames: propagate the existing masks with AOT.
                pred_mask = segtracker.track(frame, update_memory=True)
            torch.cuda.empty_cache()
            gc.collect()

            # Draw the mask on the frame
            frame_with_mask = draw_mask(frame, pred_mask)

            # Display the processed frame with segmentation mask
            # (convert back to BGR, which is what cv2.imshow expects)
            cv2.imshow('Processed Frame', cv2.cvtColor(frame_with_mask, cv2.COLOR_RGB2BGR))
            cv2.waitKey(1)

            print("processed frame {}, obj_num {}".format(frame_idx, segtracker.get_obj_num()), end='\r')
            frame_idx += 1

print('\nfinished')

cv2.destroyAllWindows()

del segtracker
torch.cuda.empty_cache()
gc.collect()
```
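For reference, a minimal timing sketch can confirm that the slow frames are exactly the ones where `segtracker.seg` (SAM) runs. The `timed` helper below is a hypothetical addition for profiling, not part of the script above:

```python
import time

def timed(label, fn, *args, **kwargs):
    # Wrap a call such as segtracker.seg(frame) or segtracker.track(frame)
    # and print its wall-clock duration.
    t0 = time.time()
    result = fn(*args, **kwargs)
    print(f"{label}: {time.time() - t0:.3f}s")
    return result

# Example usage inside the loop above:
#   seg_mask = timed("SAM seg", segtracker.seg, frame)
#   pred_mask = timed("AOT track", segtracker.track, frame, update_memory=True)
```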

LiNO3Dy commented 1 year ago

Thank you for your question. We recommend increasing the `sam_gap` parameter, for example to 99999, to adapt to real-time scenarios. To track new objects, our model runs SAM segmentation every `sam_gap` frames and adds the newly found objects, which is why the processing time spikes every `sam_gap` frames. If you have any further questions or concerns, please do not hesitate to let us know.
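In the script above, that amounts to changing one field of `segtracker_args`. A minimal sketch, assuming no new objects need to be discovered after the first frame:

```python
segtracker_args = {
    'sam_gap': 99999,        # SAM then effectively runs only on the first frame;
                             # AOT propagates the initial masks on every later frame
    'min_area': 200,
    'max_obj_num': 255,
    'min_new_obj_iou': 0.8,
}
```

The trade-off is that objects entering the scene after initialization will not be picked up, since `find_new_objs` only runs on SAM frames.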

525753936 commented 1 year ago

Appreciate your patience and response, it worked at last. Thanks.
