Closed 525753936 closed 1 year ago
My bad, here is the code.
import os import cv2 import rospy from PIL import Image from sensor_msgs.msg import Image as ROSImage from cv_bridge import CvBridge, CvBridgeError from SegTracker import SegTracker from model_args import aot_args,sam_args,segtracker_args from aot_tracker import _palette import numpy as np import torch from scipy.ndimage import binary_dilation import gc
# State shared between the subscriber callback and the processing loop.
cv_image = None         # latest frame as a "bgr8" OpenCV image; None until one arrives
image_received = False  # set to True by image_callback after a successful conversion
bridge = CvBridge()     # used by image_callback to convert ROS image messages
def image_callback(msg):
    """Convert an incoming ROS image message and publish it to the main loop.

    The message is converted to a "bgr8" OpenCV image with the module-level
    ``bridge``. On success the result is stored in the global ``cv_image``
    and the global ``image_received`` flag is set to True. If the conversion
    raises ``CvBridgeError`` the error is printed and both globals are left
    unchanged.

    Args:
        msg: The message delivered by the image subscriber.
    """
    global cv_image, image_received
    try:
        converted = bridge.imgmsg_to_cv2(msg, "bgr8")
    except CvBridgeError as err:
        print(err)
    else:
        cv_image = converted
        image_received = True
# Register this script as a ROS node, then route every message published on
# the RGB camera topic to image_callback.
rospy.init_node(name='zed2_image_listener', anonymous=True)
image_sub = rospy.Subscriber(
    name="/camera/rgb/image_raw",
    data_class=ROSImage,
    callback=image_callback,
)
def save_prediction(pred_mask, output_dir, file_name):
    """Save a label mask as a palettised image file.

    The mask is cast to uint8, converted to a mode-'P' PIL image, given the
    colours from ``_palette`` and written to ``output_dir/file_name``.

    Args:
        pred_mask: Array of object ids (cast to ``np.uint8`` before saving).
        output_dir: Directory the file is written into.
        file_name: Name of the file inside ``output_dir``.
    """
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    indexed.save(os.path.join(output_dir, file_name))
def colorize_mask(pred_mask):
    """Return an RGB rendering of a label mask using ``_palette``.

    Args:
        pred_mask: Array of object ids (cast to ``np.uint8`` before colouring).

    Returns:
        A NumPy array built from the mode-'RGB' PIL image of the mask.
    """
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    return np.array(indexed.convert(mode='RGB'))
def draw_mask(img, mask, alpha=0.5, id_countour=False):
    """Blend a coloured segmentation mask onto an image and outline it in black.

    Args:
        img: Image array. It is indexed as ``img[mask_2d, :]`` and blended
            with 3-component colours, so it is expected to be H x W x 3 with
            ``mask`` matching its first two dimensions.
        mask: Integer label array; 0 means background.
        alpha: Weight of the overlay colour. The image keeps ``1 - alpha``.
        id_countour: If True, every object id is blended with its own
            ``_palette`` colour (black for ids above 255) and gets its own
            outline. If False, all non-zero pixels are treated as one region
            coloured by ``colorize_mask`` with a single outline.

    Returns:
        A new array with the dtype of ``img``. ``img`` itself is not modified.
    """
    # Work on a copy: the previous code aliased `img`, so the caller's image
    # was painted over and later objects were blended from modified pixels.
    img_mask = img.copy()

    if id_countour:
        obj_ids = np.unique(mask)
        for obj_id in obj_ids[obj_ids != 0]:
            # Plain int: `obj_id * 3` on a uint8 scalar overflows for ids
            # above 85 and would select the wrong palette entry.
            obj_id = int(obj_id)
            if obj_id <= 255:
                color = _palette[obj_id * 3:obj_id * 3 + 3]
            else:
                color = [0, 0, 0]

            # Overlay the object's colour on its binary mask.
            foreground = img * (1 - alpha) + alpha * np.array(color)
            binary_mask = (mask == obj_id)
            img_mask[binary_mask] = foreground[binary_mask]

            # One-pixel ring just outside the object, drawn in black.
            contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
            img_mask[contours, :] = 0
    else:
        binary_mask = (mask != 0)
        contours = binary_dilation(binary_mask, iterations=1) ^ binary_mask
        foreground = img * (1 - alpha) + colorize_mask(mask) * alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[contours, :] = 0

    return img_mask.astype(img.dtype)
# Local override of the segtracker_args imported from model_args.
# 'sam_gap' is read by the main loop: segtracker.seg() is run on every frame
# whose index is a multiple of it.
# NOTE(review): the remaining keys are consumed inside SegTracker, which is
# not shown here; confirm their meaning there before tuning them.
segtracker_args = dict(
    sam_gap=5,
    min_area=200,
    max_obj_num=255,
    min_new_obj_iou=0.8,
)
# Build the tracker from the three argument sets and reset it.
segtracker = SegTracker(segtracker_args, sam_args, aot_args)
segtracker.restart_tracker()

# Release PyTorch's cached CUDA memory and run a garbage-collection pass
# before the processing loop starts.
torch.cuda.empty_cache()
gc.collect()

# Loop bookkeeping: how often segtracker.seg() is re-run, and the index of
# the next frame to process.
sam_gap = segtracker_args['sam_gap']
frame_idx = 0
# Main processing loop: segment/track each newly received camera frame and
# display the result until ROS shuts down. Everything runs inside
# torch.cuda.amp.autocast().
with torch.cuda.amp.autocast():
    while not rospy.is_shutdown():
        if not image_received:
            # No unprocessed frame yet: sleep briefly instead of spinning a
            # CPU core. rospy.sleep raises ROSInterruptException when ROS
            # shuts down during the sleep, which ends the loop.
            try:
                rospy.sleep(0.001)
            except rospy.ROSInterruptException:
                break
            continue

        # Consume the frame. The flag is cleared before the image is read so
        # that a frame arriving during processing is picked up next time
        # round, and the same frame is never processed twice.
        image_received = False
        frame = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

        if frame_idx == 0:
            # First frame: segment it and use the result as the reference.
            pred_mask = segtracker.seg(frame)
            torch.cuda.empty_cache()
            gc.collect()
            segtracker.add_reference(frame, pred_mask)
        elif (frame_idx % sam_gap) == 0:
            # Every sam_gap-th frame: segment again, track, and merge any
            # newly found objects into the tracked mask.
            seg_mask = segtracker.seg(frame)
            torch.cuda.empty_cache()
            gc.collect()
            track_mask = segtracker.track(frame)
            new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)
            pred_mask = track_mask + new_obj_mask
            segtracker.add_reference(frame, pred_mask)
        else:
            # All other frames: tracking only.
            pred_mask = segtracker.track(frame, update_memory=True)
        torch.cuda.empty_cache()
        gc.collect()

        # Draw the mask on the frame
        frame_with_mask = draw_mask(frame, pred_mask)

        # Display the processed frame with segmentation mask. The frame was
        # converted to RGB above, while cv2.imshow expects BGR, so convert
        # back to keep the on-screen colours correct.
        cv2.imshow('Processed Frame',
                   cv2.cvtColor(frame_with_mask, cv2.COLOR_RGB2BGR))
        cv2.waitKey(1)

        print("processed frame {}, obj_num {}".format(frame_idx, segtracker.get_obj_num()), end='\r')
        frame_idx += 1

print('\nfinished')
cv2.destroyAllWindows()

# Drop the tracker and release the memory it held.
del segtracker
torch.cuda.empty_cache()
gc.collect()
Thank you for your question. We recommend increasing the "sam_gap" parameter, for example to 99999, to adapt to real-time scenarios. To track new objects, our model uses the SAM for segmentation every "sam_gap" frames and adds new objects. This is why the processing time increases every "sam_gap" frames. If you have any further questions or concerns, please do not hesitate to let us know.
I appreciate your patience and your responses; it worked at last.
Thanks.
***@***.***
---- Replied Message ----
From
***@***.***>
Date
7/8/2023 11:10
To
***@***.***>
Cc
yuyang ***@***.***>
,
***@***.***>
Subject
Re: [z-x-yang/Segment-and-Track-Anything] Questions about running camera in real time (Issue #74)
Thank you for your question. We recommend increasing the "sam_gap" parameter, for example to 99999, to adapt to real-time scenarios. To track new objects, our model uses the SAM for segmentation every "sam_gap" frames and adds new objects. This is why the processing time increases every "sam_gap" frames. If you have any further questions or concerns, please do not hesitate to let us know.
— Reply to this email directly, view it on GitHub, or unsubscribe. You are receiving this because you authored the thread. Message ID: @.***>
hello,
I hope this finds you well. Thanks for sharing your amazing work! I have some questions about your code. Because I am not familiar with the principles of segmentation and just want to use it for my own application, I noticed a parameter called "sam_gap" in the "demo.ipynb" file, which seemingly controls when a new pred_mask is generated. What confused me was that when I ran the modified code on the camera in real time, it seemed to get stuck for a few seconds after processing 4 frames. Can I change this number to a higher value, or should I keep it the same? Is there any technique to run it in real time? You can check the modified code anyway. Here is the modified code: /home/zhengchen/Segment-and-Track-Anything-main/test2.py Thanks again for your marvelous work!