Closed by Cram3r95 1 year ago.
Hey,
The input shapes of the loss are:
- preds: [N, 1, M, T, 2]
- gts: list of length N, where entry i has shape [A_i, T, 2]

with
- N: batch size
- M: number of modes (6 for Argoverse)
- T: number of timesteps to be predicted (30 for Argoverse)
- A_i: number of agents in sequence i
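A quick way to sanity-check these shapes is to build dummy inputs with them. This is a minimal sketch: the batch size and the per-sequence agent counts (A_i) below are arbitrary, hypothetical values; only M = 6 and T = 30 come from the Argoverse setting mentioned above.

```python
import torch

N = 4   # batch size (arbitrary for this sketch)
M = 6   # number of modes (Argoverse)
T = 30  # number of predicted timesteps (Argoverse)

# preds: one target agent per sample, hence the singleton second dimension.
preds = torch.randn(N, 1, M, T, 2)

# gts: a plain Python list of length N; entry i holds the futures of all A_i
# agents in sequence i, so its first dimension varies from sample to sample.
num_agents = [3, 7, 1, 5]  # hypothetical A_i values
gts = [torch.randn(a_i, T, 2) for a_i in num_agents]

print(preds.shape)  # torch.Size([4, 1, 6, 30, 2])
print([g.shape for g in gts])
```

The list is needed for gts because the A_i differ, so the per-sample tensors cannot be stacked into one tensor without padding.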
I agree that the format is a little confusing; this has historical reasons. The initial idea was to implement the loss for all agents, not only the target agent. Therefore, the "gts" input is a list in which each entry contains the futures of all agents of that sequence. The final version is trained on the target agent only, which is why the "preds" input has size 1 in the second dimension. I kept this dimension as a reminder that the loss is based on a single agent per sample (namely the target agent). The predictions do not need to be a list, because their number of agents is no longer variable.
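To make this concrete, here is a minimal sketch of how a loss could consume inputs in this format. It is not the repository's implementation: the function name `min_of_modes_loss` is hypothetical, it assumes the target agent is stored at index 0 of each gts entry, it uses smooth L1 as the element-wise regression loss, and it keeps only the best of the M modes per sample (winner-takes-all). The double sum follows the `torch.sum(torch.sum(loss_single, dim=2), dim=1)` pattern: on a [N*M, T, 2] tensor it first sums over the x/y coordinates, then over the timesteps.

```python
import torch
import torch.nn.functional as F


def min_of_modes_loss(preds, gts):
    """Hypothetical winner-takes-all loss for the target agent.

    preds: [N, 1, M, T, 2]; gts: list of N tensors, each [A_i, T, 2].
    Assumption: the target agent is at index 0 of every gts entry.
    """
    n, _, m, t, d = preds.shape
    # Drop the singleton agent dimension and flatten modes into the batch:
    # [N, M, T, 2] -> [N*M, T, 2], ordered sample-major (all modes of sample 0 first).
    preds_flat = preds[:, 0].reshape(n * m, t, d)

    # Target-agent ground truth per sample: [N, T, 2], then repeat each sample
    # once per mode so it lines up with preds_flat: [N*M, T, 2].
    gt_target = torch.stack([g[0] for g in gts], dim=0)
    gt_target = torch.repeat_interleave(gt_target, m, dim=0)

    # Element-wise regression loss (smooth L1 is an assumption here): [N*M, T, 2].
    loss_single = F.smooth_l1_loss(preds_flat, gt_target, reduction="none")
    # Sum over the x/y coordinates (dim=2), then over the timesteps (dim=1): [N*M].
    loss_single = torch.sum(torch.sum(loss_single, dim=2), dim=1)

    # Regroup into [N, M] and keep only the best mode of each sample.
    best_mode_loss, _ = loss_single.view(n, m).min(dim=1)
    return best_mode_loss.sum()
```

Taking the minimum over modes instead of averaging means only the closest hypothesis is penalized for each sample, which is a common choice for multi-modal trajectory prediction because it leaves the other modes free to cover different futures.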
Let me know if this helps.
Julian
Hi!
Your work is quite interesting. Could you provide more details about the dimensions here?
Especially in this line:
`loss_single = torch.sum(torch.sum(loss_single, dim=2), dim=1)`
For both the ground truth and the predictions, I have tried the layouts batch_size x num_modes x pred_len x data_dim (e.g. 1024 x 6 x 30 x 2) and batch_size x pred_len x num_modes x 2, but I always get the same error:
What can I do?
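For tensors laid out as batch_size x num_modes x pred_len x 2, a conversion into the shapes described in the reply might look like the sketch below. The helper name `to_loss_inputs` and its `modes_first` flag are hypothetical. It assumes one target agent per sample and a ground-truth tensor without a modes dimension ([B, T, 2]), since each gts entry is expected to be [A_i, T, 2] rather than carrying M.

```python
import torch


def to_loss_inputs(pred, gt, modes_first=True):
    """Hypothetical helper: adapt [B, M, T, 2] predictions and [B, T, 2] ground truth
    to preds [B, 1, M, T, 2] and gts = list of B tensors, each [1, T, 2].

    Set modes_first=False if pred is laid out as [B, T, M, 2].
    """
    if not modes_first:
        # [B, T, M, 2] -> [B, M, T, 2]
        pred = pred.permute(0, 2, 1, 3)
    # Add the singleton "one target agent per sample" dimension: [B, 1, M, T, 2].
    preds = pred.unsqueeze(1)
    # Iterating a tensor yields slices along dim 0; wrap each sample's
    # ground truth as a single-agent tensor of shape [1, T, 2].
    gts = [g.unsqueeze(0) for g in gt]
    return preds, gts
```

With these inputs the target agent is the only agent in each gts entry, so it sits at index 0.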