There seems to be an issue with data types when running your validation scripts.
The transformation function torch_pos2global (issue) and the scores dict in the WaymoPostProcessing class expect double values, so passing float32 tensors fails with:
File ".../data_modules/waymo_post_processing.py", line 118, in forward
waymo_scores[scene_indices, pred_dict["ref_idx"]] = scores # [n_scene, n_agent, k_pred]
RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.
Therefore, my change casts the corresponding tensors to double.
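A minimal sketch of the failure mode and the fix, assuming PyTorch. Assignment through tensor (advanced) indices goes through index put, which, as the traceback shows, requires the source and destination dtypes to match. A float32 scores tensor therefore cannot be written into a float64 buffer without a cast. The helper put_scores, the shapes, and the index values below are hypothetical and only mirror the line from the traceback; this is not the repository's code.

```python
import torch


def put_scores(waymo_scores: torch.Tensor,
               scene_indices: torch.Tensor,
               ref_idx: torch.Tensor,
               scores: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: write per-agent scores into a preallocated buffer.

    Index put requires matching dtypes, so the source is cast to the
    destination's dtype (float64 here) before the assignment.
    """
    waymo_scores[scene_indices, ref_idx] = scores.to(waymo_scores.dtype)
    return waymo_scores


torch.manual_seed(0)
n_scene, n_agent, k_pred = 2, 3, 4

# Destination buffer in double precision, as the post-processing expects.
waymo_scores = torch.zeros(n_scene, n_agent, k_pred, dtype=torch.float64)

# [n_scene, 1] scene indices broadcast against [n_scene, n_selected] agent indices.
scene_indices = torch.arange(n_scene).unsqueeze(1)
ref_idx = torch.tensor([[0, 2], [1, 0]])

# Model output in single precision: [n_scene, n_selected, k_pred].
scores = torch.rand(n_scene, 2, k_pred)

out = put_scores(waymo_scores, scene_indices, ref_idx, scores)
print(out.dtype)  # torch.float64
```

Casting the source up to the destination's dtype (rather than downcasting the buffer) keeps the double precision the post-processing expects, which is the same direction as this change.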