hkchengrex / Cutie

[CVPR 2024 Highlight] Putting the Object Back Into Video Object Segmentation
https://hkchengrex.com/Cutie/

Trying to reproduce SAMTrack-esque demo with Cutie #51

vineetparikh closed this issue 3 months ago

vineetparikh commented 3 months ago

Hi there, thanks so much for making this! This is really cool!

I'm trying to adapt this to a SAMTrack-like demo (similar to https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/demo_instseg.ipynb): I have an existing image-level segmentation system that can give me masks for the objects I want to track, and I want to run it at intervals over the video so that new masks are picked up and tracked as objects enter the frame. However, my implementation runs into torch shape errors: during chunk-by-chunk inference in the mask encoder, the self.sensory_update call fails because the inputs "g" and "h" have different sizes and cannot be concatenated.

Is there a way to adapt the current demo to this kind of "online segmentation and tracking" framework? Since I have a separate source of ID-based masks, all I need is a way to get tracked masks from Cutie without updating its state, merge the tracked and segmented masks, and then update Cutie's memory with the merged result.
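
For concreteness, here is a minimal sketch of the loop I'm aiming for, based on Cutie's scripting API; run_segmenter and merge_masks are hypothetical stand-ins for my image-level segmentation system and for the merging logic, and frames is any iterable of PIL images:

    import torch
    from torchvision.transforms.functional import to_tensor
    from cutie.inference.inference_core import InferenceCore
    from cutie.utils.get_default_model import get_default_model

    cutie = get_default_model()
    processor = InferenceCore(cutie, cfg=cutie.cfg)
    seg_interval = 30  # re-run the image-level segmenter every 30 frames

    with torch.inference_mode():
        for ti, frame in enumerate(frames):  # frames: an iterable of PIL images
            image = to_tensor(frame).cuda().float()
            if ti % seg_interval == 0:
                # run the external segmenter (hypothetical), merge its mask
                # with Cutie's current prediction, and feed the result back in
                seg_mask = run_segmenter(frame)
                mask, objects = merge_masks(seg_mask, processor)  # hypothetical
                prob = processor.step(image, mask, objects=objects)
            else:
                # in-between frames: pure tracking, no new masks
                prob = processor.step(image)
            out_mask = torch.argmax(prob, dim=0)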

hkchengrex commented 3 months ago

Hello,

We have a separate project, DEVA, for this type of application. Technically, Cutie can be extended in the same way (people have done so), but that is not implemented here.

If the sensory memory has a different size from g, it might be that you forgot to initialize the sensory memory for new objects when they were added.
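
As a rough illustration (the names below are hypothetical; the actual state lives inside Cutie's memory manager), initializing the sensory memory for new objects amounts to appending zero-initialized slices along the object dimension so that it matches g:

    import torch

    # Hypothetical sketch: 'sensory' is the per-object hidden state with shape
    # (num_objects, C, H, W). When new objects are added, append zero slices so
    # its object dimension matches that of the mask features 'g'.
    def extend_sensory(sensory: torch.Tensor, num_new_objects: int) -> torch.Tensor:
        new_slices = torch.zeros(num_new_objects, *sensory.shape[1:],
                                 device=sensory.device, dtype=sensory.dtype)
        return torch.cat([sensory, new_slices], dim=0)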

vineetparikh commented 3 months ago

Hi, thanks for following up!

In that case, I'm a bit confused about how to initialize the sensory memory for new objects, since my understanding of the inference engine is that simply passing in a new mask and object list should be enough. Here's the code I currently have for merging the segmentation and tracking masks:

    import numpy as np
    import torch
    from PIL import Image
    from torchvision.transforms.functional import to_tensor

    def merge_seg_and_track_masks(
        self, seg_mask: Image.Image, track_mask: Image.Image, frame: Image.Image
    ):
        """
        Finds new objects between the h23 (segmentation) and Cutie (tracking)
        masks, updates the tracker, and returns the merged mask.
        """
        track_mask_np = np.array(track_mask)
        seg_mask_np = np.array(seg_mask)
        # Candidate new objects: segmented pixels not covered by any tracked object
        new_obj_mask = (track_mask_np == 0) * seg_mask_np
        new_obj_ids = np.unique(new_obj_mask)
        new_obj_ids = new_obj_ids[new_obj_ids != 0]
        print("New obj ids: " + str(new_obj_ids))
        obj_num = self.curr_idx
        # Relabel into a fresh array so that a newly assigned id cannot collide
        # with a not-yet-processed segmentation id in the same loop
        relabeled_mask = np.zeros_like(new_obj_mask)
        for idx in new_obj_ids:
            new_obj_area = np.sum(new_obj_mask == idx)
            obj_area = np.sum(seg_mask_np == idx)
            # Skip candidates that mostly overlap existing tracks, are too
            # small, or would exceed the object budget
            if (
                new_obj_area / obj_area < self.min_new_obj_iou
                or new_obj_area < self.min_area
                or obj_num > self.max_obj_num
            ):
                continue
            relabeled_mask[new_obj_mask == idx] = obj_num
            obj_num += 1
        new_obj_mask = relabeled_mask
        # If the new objects cover more than 40% of the frame, assume the
        # segmentation failed and add nothing
        if np.sum(new_obj_mask > 0) > frame.size[0] * frame.size[1] * 0.4:
            new_obj_mask = np.zeros_like(new_obj_mask)
        else:
            self.curr_idx = obj_num
        # New objects only occupy pixels where the track mask is 0, so the
        # two masks can simply be summed
        full_mask = track_mask_np + new_obj_mask
        # If the full mask is empty, there are no objects at all to track
        if np.sum(full_mask) == 0:
            return track_mask  # the original track mask is also empty

        # All object ids present in the merged mask, excluding background (0)
        object_list = np.unique(full_mask)
        object_list = object_list[object_list != 0].tolist()
        fmt = torch.from_numpy(full_mask.astype(np.uint8)).cuda()
        ft = to_tensor(frame).cuda().float()
        print("FT shape: ", ft.shape)
        print("FMT shape: ", fmt.shape)
        print("FMT values: ", np.unique(fmt.cpu().numpy()))
        print("Object list: ", object_list)
        # Step Cutie with the merged mask and the full object list
        ctm = self.cutie_tracker.step(ft, fmt, objects=object_list)
        ctm = torch.argmax(ctm, dim=0).detach().cpu().numpy().astype(np.uint8)

        return ctm

The track mask is obtained by calling the step function statelessly (i.e., with end set to True so the memory isn't updated before new objects are added). However, in situations where objects go out of frame but no new objects are added, g ends up with a different shape from h: g contains 3 objects while h contains only 1. In the videos I'm testing, there are in fact many cases where objects we want to track appear and then disappear over the course of the video.
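
For reference, the failure reduces to something like this (a hypothetical minimal reproduction; the (num_objects, C, H, W) shape convention and the concrete sizes are my assumptions):

    import torch

    # Hypothetical reproduction of the mismatch described above: g carries
    # features for 3 objects while the sensory memory h holds only 1, so
    # concatenating along the channel dimension fails.
    g = torch.zeros(3, 256, 30, 54)  # mask features for 3 objects (sizes assumed)
    h = torch.zeros(1, 256, 30, 54)  # sensory memory left with 1 object
    torch.cat([g, h], dim=1)  # RuntimeError: sizes must match except in dim 1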

The two main questions I have are:

Thanks so much for your help!

hkchengrex commented 3 months ago

Thanks for the follow-up. I'll look into this.

hkchengrex commented 3 months ago

Hello, I have tested adding and removing objects, and both work fine. I have added an example here; hopefully it can help you debug your implementation.
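
In brief, the pattern in that example looks roughly like this (a condensed sketch, not the full script; the frame indices and masks are placeholders):

    import torch
    from torchvision.transforms.functional import to_tensor

    # processor is an InferenceCore set up as in the scripting demos;
    # frames, first_mask, and later_mask are placeholders
    for ti, frame in enumerate(frames):
        image = to_tensor(frame).cuda().float()
        if ti == 0:
            # start tracking objects 1 and 2 from an initial mask
            prob = processor.step(image, first_mask, objects=[1, 2])
        elif ti == 30:
            # add object 3 mid-video: the mask covers the new object and the
            # objects list names every object the mask accounts for
            prob = processor.step(image, later_mask, objects=[1, 2, 3])
        elif ti == 60:
            # stop tracking object 2
            processor.delete_objects([2])
            prob = processor.step(image)
        else:
            prob = processor.step(image)
        mask = torch.argmax(prob, dim=0)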

vineetparikh commented 3 months ago

Thanks so much! I'll use this and will let you know if I have any follow-up questions.

hkchengrex commented 3 months ago

Sounds good.