facebookresearch / co-tracker

CoTracker is a model for tracking any point (pixel) on a video.
https://co-tracker.github.io/

CoTracker fails to track points on small number of frames #40

Closed MaxTeselkin closed 11 months ago

MaxTeselkin commented 11 months ago

Hi, thanks for the interesting architecture! I have found an interesting behaviour of the model: when I try to track a point over a small number of frames (3 frames), the model always predicts zero coordinates (0, 0). But when I track the same point on the same video over a larger number of frames (15, for example), it tracks really well.

Is it a bug, or am I using the model incorrectly?

My usage is quite standard:

input_video = torch.from_numpy(np.array(rgb_images)).permute(0, 3, 1, 2)[None].float()
input_video = input_video.to(self.device)
query = torch.tensor([[0, point_x, point_y]]).float()
query = query.to(self.device)
pred_tracks, pred_visibility = self.model(input_video, queries=query[None])
pred_tracks = pred_tracks.squeeze().cpu()[1:]

I guess CoTracker was trained on a dataset that does not contain sequences as short as 3 frames, so the model fails on this kind of data. Am I right?

nikitakaraevv commented 11 months ago

Hi @MaxTeselkin, thank you for the question!

We did not expect CoTracker to be applied to such short videos :) I think it outputs zeros because of this line: https://github.com/facebookresearch/co-tracker/blob/4f297a92fe1a684b1b0980da138b706d62e45472/cotracker/models/core/cotracker/cotracker.py#L264 self.S (the window size) is 8, so for a 3-frame video T - self.S // 2 is 3 - 4 = -1, which leads to the while loop being skipped. It would be great to make CoTracker work for such short videos as well. I'll think about it!
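To make the arithmetic concrete, here is a minimal sketch of a sliding-window loop of that shape. The loop form (`while ind < T - S // 2`), the start index of 0, and the stride of `S // 2` are assumptions made for illustration, and `count_window_iterations` is a hypothetical helper, not part of CoTracker.

```python
def count_window_iterations(num_frames, window_size=8):
    """Count iterations of a loop shaped like `while ind < T - S // 2`.

    Assumption: the index starts at 0 and advances by half a window
    (S // 2) per iteration. This mirrors the condition quoted above,
    not the library's actual implementation.
    """
    stride = window_size // 2
    ind = 0
    iterations = 0
    while ind < num_frames - stride:
        iterations += 1
        ind += stride
    return iterations


# With 3 frames the bound is 3 - 4 = -1, so the loop body never runs.
short_video_iterations = count_window_iterations(3)
# With 15 frames the bound is 15 - 4 = 11, so the loop runs for ind = 0, 4, 8.
long_video_iterations = count_window_iterations(15)
```

If the output tensor starts out filled with zeros (also an assumption here), a loop that never runs would leave the predicted coordinates at (0, 0), which matches the reported behaviour.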

MaxTeselkin commented 11 months ago

Honestly, I was able to make it work on short sequences with a simple trick that artificially lengthens the input sequence, @nikitakaraevv. If the input sequence is shorter than 11 frames, I lengthen it by duplicating the last frame as many times as needed. For example, if the input length is 4 (the input frame with points + 3 frames to track), I append 7 duplicates of the last frame to the frame list, pass them to the model, and trim the prediction list at the end.

Here is what it looks like:

# cotracker fails to track short sequences, so it is necessary to lengthen them by duplicating last frame
lengthened = False
if len(rgb_images) < 11:
    lengthened = True
    original_length = len(rgb_images) - 1  # do not include input frame
    while len(rgb_images) < 11:
        rgb_images.append(rgb_images[-1])
# disable gradient calculation
torch.set_grad_enabled(False)
input_video = torch.from_numpy(np.array(rgb_images)).permute(0, 3, 1, 2)[None].float()
input_video = input_video.to(self.device)
query = torch.tensor([[0, point_x, point_y]]).float()
query = query.to(self.device)
pred_tracks, pred_visibility = self.model(input_video, queries=query[None])
pred_tracks = pred_tracks.squeeze().cpu()[1:]
if lengthened:
    pred_tracks = pred_tracks[:original_length] # shorten output if necessary

And it works perfectly: predictions now look good even if I track over only one frame.
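The same padding trick could also be factored into two small helpers that leave the caller's frame list unmodified. This is only a sketch: `pad_with_last_frame` and `trim_predictions` are hypothetical names, and the threshold of 11 frames is taken from the snippet above rather than from CoTracker itself.

```python
def pad_with_last_frame(frames, min_length=11):
    """Return a new list padded to `min_length` by repeating the last frame.

    The input list is not modified. `min_length=11` follows the workaround
    described above and is not a documented CoTracker requirement.
    """
    if not frames:
        raise ValueError("frames must contain at least one frame")
    padding = [frames[-1]] * max(0, min_length - len(frames))
    return list(frames) + padding


def trim_predictions(predictions, num_original_frames):
    """Drop the query frame (index 0) and predictions for padded frames."""
    return predictions[1:num_original_frames]


# Stand-in data: strings instead of images, integers instead of track points.
frames = ["f0", "f1", "f2", "f3"]
padded = pad_with_last_frame(frames)
fake_predictions = list(range(len(padded)))
kept = trim_predictions(fake_predictions, len(frames))
```

Here `padded` has 11 entries while `frames` still has 4, and `kept` holds only the predictions for the 3 frames that were actually tracked.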

Regarding the reason for using such short sequences: I simply used them for debugging and thought that my code was incorrect until I tried tracking on longer sequences :) Anyway, in my opinion CoTracker is the best model for point tracking right now, good job!