Large-Trajectory-Model / ATM

Official codebase for "Any-point Trajectory Modeling for Policy Learning"
https://xingyu-lin.github.io/atm/
MIT License

Bug in `flow_utils.py` leading to incorrect sampling of visible track points #17

Closed: ahadjawaid closed this issue 1 month ago

ahadjawaid commented 1 month ago

In the code, the following line:

vis_idx = torch.where(vis[0] > 0)[0]

should be:

vis_idx = torch.where(vis[0] > 0)[1]

Currently, the code is only sampling indices from the first 16 points rather than sampling randomly from visible points across all the pixels. This is because the dimension [0] selects from the batch indices rather than the track point indices. Changing it to [1] will correctly sample from the visible track points.

Link to code: flow_utils.py#L147

Holmes-GU commented 1 month ago

Actually, vis_idx = torch.where(vis[0] > 0)[0] is right. The shape of vis is [T, N].

ahadjawaid commented 1 month ago

Unless I misunderstood, the point of getting vis_idx is to sample the points that are most likely to be visible. If you use [0], it will only give you indices that are in the range of the T dimension. So if you use these indices in the following code, sampled_tracks = tracks[:, vis_idx[idx]], you will select only points within the range of T and not the full N.

Holmes-GU commented 1 month ago

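The shape of vis is [T, N], so the shape of vis[0] is [N]. Because vis[0] is 1-D, torch.where(vis[0] > 0) returns a tuple with a single index tensor, and its [0] entry has shape [N'] (N' <= N). Those values are indices along the N dimension, not the T dimension.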
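
The shape argument above can be checked with a small standalone sketch. It is not the code from flow_utils.py. The sizes T, N, and num_samples, the [T, N, 2] layout of tracks, and the use of torch.randperm to draw idx are assumptions made for illustration; only the torch.where line and the tracks[:, vis_idx[idx]] indexing come from this thread.

```python
import torch

# Hypothetical sizes, chosen only for this illustration.
T, N, num_samples = 4, 32, 8

torch.manual_seed(0)
vis = torch.zeros(T, N)          # visibility mask, shape [T, N]
vis[0, 20:] = 1.0                # in frame 0, only points 20..31 are visible
tracks = torch.rand(T, N, 2)     # assumed layout: [T, N, 2] (x, y per point)

# With a single argument, torch.where returns a tuple holding one index
# tensor per dimension of the condition. vis[0] is 1-D, so the tuple has
# exactly one entry and [0] is the only valid index into it.
where_out = torch.where(vis[0] > 0)
vis_idx = where_out[0]           # shape [N'], values are indices into N

# Assumed sampling step: pick num_samples of the visible points at random.
idx = torch.randperm(len(vis_idx))[:num_samples]
sampled_tracks = tracks[:, vis_idx[idx]]

print(len(where_out))                               # 1
print(vis_idx.min().item(), vis_idx.max().item())   # 20 31
print(sampled_tracks.shape)                         # torch.Size([4, 8, 2])
```

In this setup the sampled indices reach up to 31, well beyond T = 4, so they range over N. Writing where_out[1] here would raise an IndexError, since the tuple has only one entry.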