zhejz / HPTR

Real-Time Motion Prediction via Heterogeneous Polyline Transformer with Relative Pose Encoding. NeurIPS 2023.
https://zhejz.github.io/hptr
Other

Cast tensors to double to run validation #7

Closed: roydenwa closed this 6 months ago

roydenwa commented 6 months ago

Hi,

there seems to be a dtype mismatch when running your validation scripts. The transformation function torch_pos2global (see the linked issue) and the scores dict in the WaymoPostProcessing class expect double-precision values, so validation fails with:

File ".../data_modules/waymo_post_processing.py", line 118, in forward
    waymo_scores[scene_indices, pred_dict["ref_idx"]] = scores  # [n_scene, n_agent, k_pred]
RuntimeError: Index put requires the source and destination dtypes match, got Double for the destination and Float for the source.

Therefore, my change casts the corresponding tensors to double.
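A minimal PyTorch sketch that reproduces the error and shows the fix. It assumes the destination buffer is allocated as float64 and the model scores arrive as float32, as the traceback reports. The shapes, the index tensors, and the helper put_scores are hypothetical stand-ins, not code from the repository.

```python
import torch

# Hypothetical sizes: 2 scenes, 4 agent slots, 6 predictions per agent.
n_scene, n_agent, k_pred = 2, 4, 6

# Destination allocated as float64 ("Double" in the traceback).
waymo_scores = torch.zeros(n_scene, n_agent, k_pred, dtype=torch.float64)

# Hypothetical indices: one agent selected per scene.
scene_indices = torch.tensor([0, 1])
ref_idx = torch.tensor([2, 3])

# Scores in float32 ("Float" in the traceback), the usual network output dtype.
scores = torch.rand(2, k_pred)


def put_scores(dest, scenes, agents, src):
    """Advanced-index assignment; PyTorch requires matching dtypes here."""
    dest[scenes, agents] = src
    return dest


# Without a cast, index_put rejects the float32 source.
try:
    put_scores(waymo_scores, scene_indices, ref_idx, scores)
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = str(err)

# Casting the source to the destination dtype makes the assignment valid.
put_scores(waymo_scores, scene_indices, ref_idx, scores.to(waymo_scores.dtype))
```

Casting the source up to float64 leaves the destination buffer's precision unchanged. Allocating the buffer as float32 would also remove the mismatch, but it changes the dtype that the downstream code sees.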