sshaoshuai / MTR

MTR: Motion Transformer with Global Intention Localization and Local Movement Refinement, NeurIPS 2022.
Apache License 2.0
661 stars 105 forks

do not understand gmm loss #28

Closed chrisxuxuhui closed 1 year ago

chrisxuxuhui commented 1 year ago

if use_square_gmm:
    log_std1 = log_std2 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1])
    std1 = std2 = torch.exp(log_std1)  # (0.2m to 150m)
    rho = torch.zeros_like(log_std1)
else:
    log_std1 = torch.clip(nearest_trajs[:, :, 2], min=log_std_range[0], max=log_std_range[1])
    log_std2 = torch.clip(nearest_trajs[:, :, 3], min=log_std_range[0], max=log_std_range[1])
    std1 = torch.exp(log_std1)  # (0.2m to 150m)
    std2 = torch.exp(log_std2)  # (0.2m to 150m)
    rho = torch.clip(nearest_trajs[:, :, 4], min=-rho_limit, max=rho_limit)

gt_valid_mask = gt_valid_mask.type_as(pred_scores)
if timestamp_loss_weight is not None:
    gt_valid_mask = gt_valid_mask * timestamp_loss_weight[None, :]

# -log(a^-1 * e^b) = log(a) - b
reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2)  # (batch_size, num_timestamps)
reg_gmm_exp = (0.5 * 1 / (1 - rho**2)) * ((dx**2) / (std1**2) + (dy**2) / (std2**2) - 2 * rho * dx * dy / (std1 * std2))  # (batch_size, num_timestamps)

reg_loss = ((reg_gmm_log_coefficient + reg_gmm_exp) * gt_valid_mask).sum(dim=-1)

Hello, I do not understand this GMM loss. Could you please explain it in detail?
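For context on where `nearest_trajs` in the snippet comes from: below is a minimal, hedged sketch of selecting the GT-nearest mode before applying the regression loss. The function and variable names here are illustrative, not the repo's actual API.

```python
import torch

def select_nearest_mode(pred_trajs: torch.Tensor, gt_traj: torch.Tensor):
    """Pick the predicted mode whose trajectory is closest to the GT.

    pred_trajs: (K, T, 5) per-mode outputs (cx, cy, log_std1, log_std2, rho)
    gt_traj:    (T, 2) ground-truth positions
    """
    # squared distance between predicted centers and GT, summed over time
    dist = ((pred_trajs[:, :, :2] - gt_traj[None]) ** 2).sum(-1).sum(-1)  # (K,)
    nearest = int(dist.argmin())
    return nearest, pred_trajs[nearest]  # (T, 5): plays the role of nearest_trajs

pred = torch.randn(6, 80, 5)   # 6 modes, 80 future timestamps
gt = torch.randn(80, 2)
mode_idx, nearest_trajs = select_nearest_mode(pred, gt)
```

Only this selected mode then enters the Gaussian NLL terms computed in the snippet above (a "winner-takes-all" style assignment).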

sshaoshuai commented 1 year ago

We model the agent's multimodal future trajectories with a Gaussian Mixture Model (GMM), and this loss maximizes the likelihood of the GT future trajectory under that GMM. For more details, you can refer to our supplementary material or this paper: https://arxiv.org/pdf/1910.05449.pdf.
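To make the likelihood computation concrete: the per-timestep terms in the snippet at the top of this thread appear to be the negative log-likelihood of a bivariate Gaussian, minus the constant log(2π) (which does not affect gradients). A hedged cross-check against `torch.distributions`, with `gmm_nll_terms` as an illustrative name:

```python
import math
import torch

def gmm_nll_terms(dx, dy, std1, std2, rho):
    # matches reg_gmm_log_coefficient + reg_gmm_exp in the snippet above
    log_coef = torch.log(std1) + torch.log(std2) + 0.5 * torch.log(1 - rho**2)
    exp_term = (0.5 / (1 - rho**2)) * (
        (dx / std1) ** 2 + (dy / std2) ** 2 - 2 * rho * dx * dy / (std1 * std2)
    )
    return log_coef + exp_term

# cross-check on a single point: dx, dy are offsets from the predicted center
dx, dy = torch.tensor(0.3), torch.tensor(-0.5)
std1, std2, rho = torch.tensor(1.2), torch.tensor(0.8), torch.tensor(0.4)
cov = torch.stack([
    torch.stack([std1**2, rho * std1 * std2]),
    torch.stack([rho * std1 * std2, std2**2]),
])
dist = torch.distributions.MultivariateNormal(torch.zeros(2), cov)
nll_full = -dist.log_prob(torch.stack([dx, dy]))
nll_code = gmm_nll_terms(dx, dy, std1, std2, rho) + math.log(2 * math.pi)
assert torch.allclose(nll_full, nll_code, atol=1e-4)
```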

chrisxuxuhui commented 1 year ago

@sshaoshuai Thank you very much for your answer. I also read MultiPath. As it says: "The time-step distributions are assumed to be conditionally independent given an anchor, i.e., we write φ(s_t|·) instead of φ(s_t|·, s_{1:t−1}). This modeling assumption allows us to predict all time steps jointly with a single inference pass, making our model simple to train and efficient to evaluate. If desired, it is straightforward to add a conditional next-time-step dependency to our model, using a recurrent structure (RNN)." I wonder why every time step is assumed to be independent. Do you think that is OK, or is it just for ease of training? I could not understand it. Your paper makes the same assumption; do you know why?
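In MultiPath's notation, the conditional-independence assumption being discussed means the trajectory likelihood factorizes per timestep, which is exactly what lets the code compute the loss as a plain sum over timestamps:

```latex
% MultiPath notation: anchor/mode $A$, GT states $s_1, \dots, s_T$
\log p(s_{1:T} \mid A) \;=\; \sum_{t=1}^{T} \log \phi(s_t \mid A)
```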

sshaoshuai commented 1 year ago

As you mentioned, it is "simple to train and efficient to evaluate" — I think that's the answer. Of course, learning it in an autoregressive way may also achieve good performance with careful parameter tuning.

QC625 commented 1 year ago

@sshaoshuai Hello, I am also confused about the GMM loss. The calculation of this loss is based on Eq. (10) in the article, but there is no -log(p̂) term in the code implementation. Is it loss_cls?

sshaoshuai commented 1 year ago

Yes, it corresponds to the mode classification loss. @QC625
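Concretely, the -log(p̂) term of Eq. (10) is typically realized as a cross-entropy over the K mode scores, with the GT-nearest mode as the target label. A hedged sketch with illustrative names (not the repo's actual variables):

```python
import torch
import torch.nn.functional as F

pred_scores = torch.randn(4, 6)            # (batch, K) unnormalized mode logits
nearest_mode = torch.tensor([2, 0, 5, 1])  # index of the GT-nearest mode per sample

# cross_entropy computes -log softmax(pred_scores)[nearest_mode],
# i.e. exactly the -log(p_hat) term for the selected mode
loss_cls = F.cross_entropy(pred_scores, nearest_mode)
```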

Xinchengzelin commented 1 year ago


Hello, is this Eq. (10) designed by yourself? I found this material, and when I calculate the loss as in the equation on page 5 of that material, numerical instability occurs, so I can't train the model.
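For what it's worth, the clipping in the snippet at the top of this thread is likely what prevents that kind of instability: without bounds, exp(log_std) can underflow/overflow and log(1 - rho**2) can hit log(0), producing inf/nan losses. A small sketch (the range values mirror the comments in that snippet):

```python
import torch

log_std_range = (-1.609, 5.0)  # std roughly in (0.2m, 150m), as in the code
rho_limit = 0.5

# deliberately extreme raw network outputs
raw_log_std = torch.tensor([-20.0, 0.0, 30.0])
raw_rho = torch.tensor([-0.999, 0.0, 0.999])

log_std = torch.clip(raw_log_std, min=log_std_range[0], max=log_std_range[1])
rho = torch.clip(raw_rho, min=-rho_limit, max=rho_limit)

# after clipping, every term of the NLL stays finite
assert torch.isfinite(torch.exp(log_std)).all()
assert torch.isfinite(torch.log(1 - rho**2)).all()
```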