google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0

`plot_tracks_v2` has bug when plotting with `trackgroup` argument. #92

Open chandlj opened 4 months ago

chandlj commented 4 months ago

I am running this notebook for RoboTAP clustering. After computing the clusters, I am running the following cell:

separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

pointtrack_video = viz_utils.plot_tracks_v2(
    (demo_videos[demo_episode_ids[0]]).astype(np.uint8),
    separation_tracks_trim[demo_episode_ids[0]],
    1.0-separation_visibility_trim[demo_episode_ids[0]],
    trackgroup=clustered['classes']
)
media.show_video(pointtrack_video, fps=20)

However, the plot shows only about 10 points no matter how many points I track, and no clusters are visible. If I comment out trackgroup, the plotting code works correctly and I can see the full set of points (although they are not colored by cluster ID). I can also verify that the clusters are computed correctly by plotting individual frames like so:

separation_visibility_trim = clustered['separation_visibility']
separation_tracks_trim = clustered['separation_tracks']

frame = 35
plt.scatter(
  separation_tracks_trim["dummy_id"][:, frame, 0],
  separation_tracks_trim["dummy_id"][:, frame, 1],
  c=clustered["classes"],
  cmap="viridis",
)
plt.imshow(video[frame])

The problem occurs only when trackgroup is specified. Any ideas on how to fix it?
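One way to tell whether the ~10 drawn points reflect the data or the plotting is to count how many tracks are visible in each frame. Below is a minimal NumPy sketch. It assumes `occluded` has shape [num_points, num_frames], as in the `1.0 - separation_visibility` argument above. The helper name `visible_counts` and its 0.5 threshold are hypothetical and are not part of viz_utils.

```python
import numpy as np


def visible_counts(occluded: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Count how many tracks are visible in each frame (hypothetical helper).

    Assumes `occluded` has shape [num_points, num_frames], with values near
    1.0 meaning occluded (as in `1.0 - separation_visibility`).
    """
    occluded = np.asarray(occluded, dtype=np.float32)
    if occluded.ndim != 2:
        raise ValueError(f"expected [num_points, num_frames], got {occluded.shape}")
    # A point counts as visible in a frame when its occlusion value is below the threshold.
    return np.sum(occluded < threshold, axis=0)


# Toy example: 4 tracks over 3 frames.
occ = np.array([
    [0.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
])
print(visible_counts(occ))  # -> [3 3 2]
```

If these counts are far above what the video shows with trackgroup set, the points are being lost somewhere between the inputs and the drawing code.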

cdoersch commented 4 months ago

Now that tapir_clustering.py is fixed, I've run the colab at head and verified that it plots more than 20 tracks. Your snippets above look correct to me; I don't see why they wouldn't plot the full set the way the colab does. Maybe set a breakpoint at https://github.com/google-deepmind/tapnet/blob/main/utils/viz_utils.py#L193 and check what's being passed to plt.scatter?
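Before stepping into plt.scatter, it can help to confirm that `trackgroup` has exactly one label per track and to see how many tracks each cluster contains. The sketch below does this check. The helper `summarize_trackgroup` is hypothetical. It assumes tracks shaped [num_points, num_frames, 2] and integer cluster labels.

```python
import numpy as np


def summarize_trackgroup(tracks: np.ndarray, trackgroup) -> dict:
    """Check `trackgroup` against `tracks` and count tracks per group (hypothetical helper).

    Assumes `tracks` has shape [num_points, num_frames, 2] and that
    `trackgroup` holds one integer cluster label per track.
    """
    tracks = np.asarray(tracks)
    groups = np.asarray(trackgroup)
    if groups.ndim != 1 or groups.shape[0] != tracks.shape[0]:
        raise ValueError(
            f"trackgroup has shape {groups.shape}, but there are {tracks.shape[0]} tracks"
        )
    if not np.issubdtype(groups.dtype, np.integer):
        raise ValueError(f"expected integer labels, got dtype {groups.dtype}")
    labels, counts = np.unique(groups, return_counts=True)
    # Map each group label to the number of tracks assigned to it.
    return {int(label): int(count) for label, count in zip(labels, counts)}


tracks = np.zeros((5, 10, 2), dtype=np.float32)
print(summarize_trackgroup(tracks, [0, 0, 1, 2, 2]))  # -> {0: 2, 1: 1, 2: 2}
```

With the notebook data this could be called as `summarize_trackgroup(separation_tracks_trim[demo_episode_ids[0]], clustered['classes'])`. A ValueError, or far fewer labeled tracks than expected, would point at the clustering output rather than the plotting code.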

chandlj commented 4 months ago

@cdoersch The most recent change pushed to tapir_clustering has a bug and did not work for me in the notebook. Looking at the commit here, it looks like changing len to np.prod on line 574 is causing the problem: I noticed that jax.tree_map(lambda x: np.prod(x.shape), query_features) actually returns 1 for the resolutions array, not 0.
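The two reductions are not interchangeable. len(x) returns only the size of the first axis, while np.prod(x.shape) returns the total element count, and the product over an empty shape tuple is 1. The thread does not show what the resolutions leaf actually looks like, so this NumPy sketch only illustrates the general difference; the arrays are made up.

```python
import numpy as np

matrix = np.zeros((2, 3))
scalar = np.array(3.0)  # 0-d array: its shape is the empty tuple ()

# len() only looks at the first axis, while np.prod(x.shape) counts every element.
print(len(matrix), np.prod(matrix.shape))  # -> 2 6

# The product over an empty shape tuple is 1, so a 0-d leaf is counted as one element...
print(scalar.shape, np.prod(scalar.shape))  # -> () 1.0

# ...whereas len() is not defined for 0-d arrays at all.
try:
    len(scalar)
except TypeError as err:
    print("TypeError:", err)
```

Any code that relies on a leaf contributing 0 (for example, to skip it) would therefore behave differently once len is replaced by np.prod.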