schmidt-ju / crat-pred


Loss dimensions #7

Closed: Cram3r95 closed this issue 1 year ago

Cram3r95 commented 1 year ago

Hi!

Your work is quite interesting. Could you provide more details about the dimensions here?

def prediction_loss(self, preds, gts):
    # Stack all the predicted trajectories of the target agent
    num_mods = preds.shape[2]
    # [0] is required to remove the unneeded dimensions
    preds = torch.cat([x[0] for x in preds], 0)

    # Stack all the true trajectories of the target agent
    # Keep in mind that there are multiple trajectories in each sample, but only the first one ([0]) corresponds
    # to the target agent
    gt_target = torch.cat([torch.unsqueeze(x[0], 0) for x in gts], 0)
    gt_target = torch.repeat_interleave(gt_target, num_mods, dim=0)

    loss_single = self.reg_loss(preds, gt_target)
    loss_single = torch.sum(torch.sum(loss_single, dim=2), dim=1)

    loss_single = torch.split(loss_single, num_mods)

    # Tuple to tensor
    loss_single = torch.stack(list(loss_single), dim=0)

    min_loss_index = torch.argmin(loss_single, dim=1)

    min_loss_combined = [x[min_loss_index[i]]
                         for i, x in enumerate(loss_single)]

    loss_out = torch.sum(torch.stack(min_loss_combined))

    return loss_out

Especially in this line:

loss_single = torch.sum(torch.sum(loss_single, dim=2), dim=1)

For both the GT and the predictions, I have tried the shapes batch_size x num_modes x pred_len x data_dim (e.g. 1024 x 6 x 30 x 2) and batch_size x pred_len x num_modes x 2, but I always get the same error:

[Screenshot of the error message]

What can I do?

schmidt-ju commented 1 year ago

Hey,

the input shapes of the loss are:
- preds: [N, 1, M, T, 2]
- gts: a list of length N, where entry i has the shape [A_i, T, 2]

- N: batch size
- M: number of modes (6 for Argoverse)
- T: number of timesteps to be predicted (30 for Argoverse)
- A_i: number of agents in sequence i
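To make these shapes concrete, here is a minimal, self-contained sketch of the same computation with the shape after each step written as a comment. It assumes `reg_loss` is an element-wise loss with `reduction="none"`; `nn.SmoothL1Loss` is only a stand-in, since the actual loss module is not shown in this thread. The split/stack/argmin/indexing steps are condensed into an equivalent `view` plus `min`.

```python
import torch
import torch.nn as nn

# Assumption: reg_loss is element-wise (reduction="none"); SmoothL1Loss is a stand-in.
reg_loss = nn.SmoothL1Loss(reduction="none")


def prediction_loss(preds, gts):
    """Sum over the batch of the best-mode (min-over-modes) loss of the target agent.

    preds: tensor [N, 1, M, T, 2]
    gts:   list of N tensors, entry i with shape [A_i, T, 2]; row 0 is the target agent
    """
    num_mods = preds.shape[2]  # M
    # Each x is [1, M, T, 2]; x[0] is [M, T, 2]; concatenation gives [N*M, T, 2]
    preds = torch.cat([x[0] for x in preds], 0)

    # x[0] is the target agent's future [T, 2]; unsqueeze + cat gives [N, T, 2]
    gt_target = torch.cat([torch.unsqueeze(x[0], 0) for x in gts], 0)
    # Repeat each ground truth once per mode so rows line up with preds: [N*M, T, 2]
    gt_target = torch.repeat_interleave(gt_target, num_mods, dim=0)

    loss_single = reg_loss(preds, gt_target)         # [N*M, T, 2]
    loss_single = loss_single.sum(dim=2).sum(dim=1)  # [N*M], one value per mode
    loss_single = loss_single.view(-1, num_mods)     # [N, M]

    # Keep the lowest-loss mode of each sample, then sum over the batch -> scalar
    return loss_single.min(dim=1).values.sum()


N, M, T = 4, 6, 30
preds = torch.randn(N, 1, M, T, 2)
gts = [torch.randn(a, T, 2) for a in (3, 5, 1, 7)]  # A_i differs per sample
loss = prediction_loss(preds, gts)
print(loss.shape)  # torch.Size([])
```

With a 4-D [N, M, T, 2] input instead, `preds.shape[2]` returns T rather than M, and `x[0]` strips the mode dimension rather than the singleton agent dimension. As a result, the tensors that reach the double `torch.sum` no longer have the [N*M, T, 2] layout this code expects.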

I agree that the format is a bit confusing; the reasons are historical. The initial idea was to implement the loss for all agents, not only the target agent. That is why the "gts" input is a list, with each entry containing the futures of all agents: the number of agents A_i differs between sequences, so the entries cannot be stacked into a single tensor. The final version is trained with the target agent only, which is why the "preds" input has size 1 in the second dimension. I kept this as a reminder that the loss is computed on a single agent per sample (namely the target agent). A list is not needed for "preds", because the number of agents is no longer variable.
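If the predictions are laid out as in the question (batch_size x num_modes x pred_len x 2) and there is one ground-truth future per sample, the inputs can likely be adapted as shown below. This is a sketch: `to_loss_inputs` is a hypothetical helper, not part of crat-pred. It assumes the ground truth contains only the target agent, so every list entry has A_i = 1. Because the loss only reads row 0 of each `gts` entry, the other agents are not needed.

```python
import torch


def to_loss_inputs(preds, gt_target):
    """Hypothetical helper (not part of crat-pred).

    preds:     [N, M, T, 2] predicted modes of the target agent
    gt_target: [N, T, 2]    ground-truth future of the target agent
    Returns preds as [N, 1, M, T, 2] and gts as a list of N tensors [1, T, 2],
    i.e. every sample contains only the target agent (A_i = 1).
    """
    preds_5d = preds.unsqueeze(1)              # insert the singleton agent dim
    gts = [g.unsqueeze(0) for g in gt_target]  # row 0 of each entry = target agent
    return preds_5d, gts


preds = torch.randn(1024, 6, 30, 2)
gt_target = torch.randn(1024, 30, 2)
p, g = to_loss_inputs(preds, gt_target)
print(p.shape, len(g), g[0].shape)
# torch.Size([1024, 1, 6, 30, 2]) 1024 torch.Size([1, 30, 2])
```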

Let me know if this helps.

Julian