facebookresearch / co-tracker

CoTracker is a model for tracking any point (pixel) on a video.
https://co-tracker.github.io/

CoTracker fails to track points on a small number of frames #40

Closed MaxTeselkin closed 1 year ago

MaxTeselkin commented 1 year ago

Hi, thanks for the interesting architecture! I have found an interesting behaviour of the model: when I try to track a point over a small number of frames (3 frames), the model always predicts zero coordinates (0, 0). But when I track the same point on the same video over a larger number of frames (15, for example), it tracks really well.

Is it a bug, or am I using the model incorrectly?

My usage is quite standard:

# (T, H, W, C) uint8 frames -> (1, T, C, H, W) float tensor expected by the model
input_video = torch.from_numpy(np.array(rgb_images)).permute(0, 3, 1, 2)[None].float()
input_video = input_video.to(self.device)
# query point in (frame_index, x, y) format
query = torch.tensor([[0, point_x, point_y]]).float()
query = query.to(self.device)
pred_tracks, pred_visibility = self.model(input_video, queries=query[None])
# drop batch/point dims and skip the query frame itself
pred_tracks = pred_tracks.squeeze().cpu()[1:]

I guess CoTracker was trained on a dataset which does not contain sequences as short as 3 frames, so the model fails on this type of data. Am I right?

nikitakaraevv commented 1 year ago

Hi @MaxTeselkin, thank you for the question!

We did not expect that CoTracker would be applied to such short videos :) I think the reason it outputs zeros is this line: https://github.com/facebookresearch/co-tracker/blob/4f297a92fe1a684b1b0980da138b706d62e45472/cotracker/models/core/cotracker/cotracker.py#L264. self.S (the window size) is 8, so T - self.S // 2 in this case is -1, which leads to the while loop being skipped. It would be great to make CoTracker work for such short videos as well. I'll think about it!
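
To make the failure mode concrete, here is a minimal sketch of the sliding-window loop condition (an illustration only, not the actual CoTracker source): with window size S = 8 and only T = 3 frames, T - S // 2 is -1, so the loop body never runs and the predictions keep their zero-initialized values.

# Illustration of the window-loop condition discussed above
# (not the actual CoTracker implementation).
def count_windows(T, S=8):
    ind, windows = 0, 0
    while ind < T - S // 2:  # skipped entirely when T - S // 2 <= 0
        windows += 1
        ind += S // 2        # slide forward by half a window
    return windows

print(count_windows(3))   # 0 -> no window is processed, tracks stay at zero
print(count_windows(15))  # 3 -> the loop runs as expected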

MaxTeselkin commented 1 year ago

Honestly, I was able to make it work on short sequences using a simple trick to artificially lengthen the input sequence, @nikitakaraevv. If the length of the input sequence is lower than 11, I simply lengthen it by duplicating the last frame as many times as needed. For example, if the input length is 4 (the query frame with points + 3 frames to track), I append 7 duplicates of the last frame to the frames list, pass them to the model and shorten the predictions list at the end.

Here is how it looks:

# CoTracker fails to track short sequences, so lengthen them by duplicating the last frame
lengthened = False
if len(rgb_images) < 11:
    lengthened = True
    original_length = len(rgb_images) - 1  # do not include the query frame
    while len(rgb_images) < 11:
        rgb_images.append(rgb_images[-1])
# disable gradient calculation
torch.set_grad_enabled(False)
input_video = torch.from_numpy(np.array(rgb_images)).permute(0, 3, 1, 2)[None].float()
input_video = input_video.to(self.device)
query = torch.tensor([[0, point_x, point_y]]).float()
query = query.to(self.device)
pred_tracks, pred_visibility = self.model(input_video, queries=query[None])
pred_tracks = pred_tracks.squeeze().cpu()[1:]
if lengthened:
    pred_tracks = pred_tracks[:original_length]  # drop predictions for the duplicated frames
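
The same padding can also be done on the video tensor itself rather than the Python list of frames; a minimal sketch of that variant (the helper name and signature are my own, not part of CoTracker):

import torch

def pad_video_to_min_length(video, min_frames=11):
    # video: (B, T, C, H, W); repeat the last frame until there are at least min_frames frames
    b, t, c, h, w = video.shape
    if t >= min_frames:
        return video
    last = video[:, -1:].expand(b, min_frames - t, c, h, w)
    return torch.cat([video, last], dim=1)

# usage: pad before the forward pass, then keep only the first T predictions,
# e.g. pred_tracks = pred_tracks[:, :t] while the batch dim is still present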

And it works perfectly - now predictions look nice even if I track on only one frame.

Regarding the reason for using such short sequences: I simply used them for debugging and thought my code was incorrect until I tried tracking on longer sequences :) Anyway, in my opinion CoTracker is the best model for point tracking right now, good job!