chandlj opened this issue 4 months ago
Now that `tapir_clustering.py` is fixed, I've run the colab at head and verified that the code will plot more than 20 tracks. Your snippets above look correct to me; I don't see why they wouldn't plot the full set the way the colab does. Maybe set a breakpoint at https://github.com/google-deepmind/tapnet/blob/main/utils/viz_utils.py#L193 and check what's being passed to `plt.scatter`?
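If a debugger is inconvenient in the notebook, a similar check can be done by spying on `plt.scatter` and recording what it receives. This is a minimal sketch, assuming the plotting code calls `plt.scatter` through the `matplotlib.pyplot` module; `plot_tracks` is a hypothetical stand-in for the real `viz_utils` routine, and the point counts are made up.

```python
from unittest import mock

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_tracks(points, colors):
    """Hypothetical stand-in for the viz_utils plotting routine."""
    fig, _ = plt.subplots()
    plt.scatter(points[:, 0], points[:, 1], c=colors)
    plt.close(fig)


n_points = 50  # made-up number of tracked points
rng = np.random.default_rng(0)

# wraps= forwards every call to the real plt.scatter while recording its args.
with mock.patch.object(plt, "scatter", wraps=plt.scatter) as spy:
    plot_tracks(rng.uniform(size=(n_points, 2)), rng.uniform(size=(n_points, 3)))

# One entry per scatter call: the shape of the x coordinates that reached it.
scatter_shapes = [np.asarray(call.args[0]).shape for call in spy.call_args_list]
print(scatter_shapes)
```

Comparing these shapes with and without `trackgroup` would show whether points are dropped before they reach matplotlib.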
@cdoersch The most recent code pushed for `tapir_clustering` has a bug and did not work for me in the notebook. Looking at the commit here, it looks like changing `len` to `np.prod` on line 574 is causing problems. I noticed that `jax.tree_map(lambda x: np.prod(x.shape), query_features)` actually returns 1 for the `resolutions` array, not 0.
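The 1-versus-0 difference is what you would expect for a leaf whose shape tuple is empty. This is a minimal sketch, assuming the expression before the change was `len(x.shape)` and that the `resolutions` leaf is a 0-d array (the thread doesn't state its shape); the shapes below are made up for illustration.

```python
import numpy as np

# Made-up leaf shapes standing in for fields of query_features.
example_shapes = {
    "batched_leaf": (1, 20, 256),
    "zero_d_leaf": (),  # e.g. a scalar array: its shape tuple is empty
}

for name, shape in example_shapes.items():
    # len(shape) counts dimensions; np.prod(shape) multiplies them.
    # For an empty shape, len gives 0 but np.prod gives the empty product 1.0.
    print(name, len(shape), np.prod(shape))
```

Note that the two expressions measure different things in general (rank versus total size), so they disagree on non-scalar leaves as well.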
I am running this notebook for RoboTAP clustering. After computing the clusters, I am running the following cell:
However, the plot only shows about 10 points no matter how many points I track, and there are really no clusters to be found. I found that if I comment out `trackgroup`, the plotting code works correctly and I can see the full range of points (although not colored by cluster ID). I can also verify that the clusters are correctly computed by plotting individual frames like so:

It's really only when `trackgroup` is specified that this code does not behave properly. Any ideas on how to fix this?
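One thing worth ruling out is a mismatch between `trackgroup` and the number of tracks. This is a minimal sketch, assuming `trackgroup` is a 1-D integer array with one cluster ID per track; `colors_for_groups` is a hypothetical helper, not part of tapnet, and the thread does not say how `viz_utils` actually consumes `trackgroup`.

```python
import numpy as np


def colors_for_groups(trackgroup, num_tracks, seed=0):
    """Hypothetical helper: one RGB color per track, shared within a cluster."""
    trackgroup = np.asarray(trackgroup)
    # Fail loudly instead of silently plotting fewer points than tracks.
    if trackgroup.shape != (num_tracks,):
        raise ValueError(
            f"expected one cluster ID per track ({num_tracks}), "
            f"got shape {trackgroup.shape}")
    rng = np.random.default_rng(seed)
    palette = rng.uniform(size=(trackgroup.max() + 1, 3))  # one color per cluster
    return palette[trackgroup]  # shape (num_tracks, 3)


colors = colors_for_groups([0, 1, 1, 2], num_tracks=4)
print(colors.shape)
```

Tracks in the same cluster get identical colors, so the resulting array can be passed as `c=` to a scatter call alongside all of the track coordinates.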