NVlabs / AL-MDN

Official PyTorch implementation of "Active Learning for Deep Object Detection via Probabilistic Modeling" (ICCV 2021)
https://openaccess.thecvf.com/content/ICCV2021/html/Choi_Active_Learning_for_Deep_Object_Detection_via_Probabilistic_Modeling_ICCV_2021_paper.html

Why use the reparameterization trick for classification loss computation? #12

Closed: TianpengBu closed this issue 2 years ago.

TianpengBu commented 2 years ago

Hi, thanks for your great work. After reading the paper, I have a question about how the classification loss is computed.

For the localization loss, it's clear that you regress the GMM means to the offsets from the anchors to the GT boxes, use the variance terms to predict the uncertainty of those offset predictions, and optimize the whole procedure by log-likelihood maximization.
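To make the localization objective concrete, here is a minimal sketch of the negative log-likelihood of a single box-offset target under a K-component 1-D Gaussian mixture. It assumes one scalar coordinate, mixture weights that are already normalized and strictly positive, and hypothetical helper names (`gaussian_log_pdf`, `gmm_nll`); it is an illustration, not the code in the AL-MDN repository.

```python
import math
from typing import Sequence


def gaussian_log_pdf(y: float, mu: float, var: float) -> float:
    """Log density of a univariate Gaussian N(y | mu, var)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (y - mu) ** 2 / var)


def gmm_nll(y: float, pis: Sequence[float], mus: Sequence[float],
            variances: Sequence[float]) -> float:
    """Negative log-likelihood of target y under a K-component 1-D GMM.

    Uses the log-sum-exp trick so tiny component densities do not underflow.
    Assumes every mixture weight in `pis` is > 0.
    """
    log_terms = [math.log(pi) + gaussian_log_pdf(y, mu, var)
                 for pi, mu, var in zip(pis, mus, variances)]
    m = max(log_terms)
    return -(m + math.log(sum(math.exp(t - m) for t in log_terms)))


# Hypothetical prediction for one box-coordinate offset with K = 2 components.
loss = gmm_nll(0.3, pis=[0.7, 0.3], mus=[0.25, -0.5], variances=[0.04, 0.25])
print(f"{loss:.4f}")
```

Minimizing this value is the same as maximizing the log-likelihood of the GT offset; a target far from every predicted mean, or one predicted with an overconfident small variance, yields a large loss.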

However, for classification, as far as I can tell, you optimize in a different way. You treat the input data as a random variable, use the reparameterization trick to draw a class-specific sample from the learned GMM, and finally compute the BCE loss between the GT and this reparameterized sample.
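The classification path described above can be sketched as follows. The sketch assumes that each mixture component k predicts a mean and variance for every class logit, samples a logit with the reparameterization trick z = mu + sqrt(var) * eps with eps ~ N(0, 1), passes it through a sigmoid, mixes the per-component probabilities with the weights pi_k, and scores the result with BCE against a one-hot label. The function names, the sigmoid/BCE pairing, and the mixing step are assumptions for illustration; the repository may combine these pieces differently.

```python
import math
import random
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def reparam_sample(mu: float, var: float, rng: random.Random) -> float:
    """z = mu + sqrt(var) * eps with eps ~ N(0, 1); the randomness lives only in eps."""
    eps = rng.gauss(0.0, 1.0)
    return mu + math.sqrt(var) * eps


def mixture_bce_loss(target: Sequence[int], pis: Sequence[float],
                     mus: Sequence[Sequence[float]],
                     variances: Sequence[Sequence[float]],
                     rng: random.Random) -> float:
    """Mean BCE between a one-hot target and mixture-weighted sampled probabilities.

    mus[k][c] and variances[k][c] are the (hypothetical) predicted mean and
    variance of the logit for class c under component k; pis[k] weights component k.
    """
    num_classes = len(target)
    probs = [0.0] * num_classes
    for pi, mu_k, var_k in zip(pis, mus, variances):
        for c in range(num_classes):
            z = reparam_sample(mu_k[c], var_k[c], rng)
            probs[c] += pi * sigmoid(z)
    tiny = 1e-7  # clamp so log() stays finite
    total = 0.0
    for t, p in zip(target, probs):
        p = min(max(p, tiny), 1.0 - tiny)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / num_classes


rng = random.Random(0)
loss = mixture_bce_loss(
    target=[0, 1, 0],
    pis=[0.6, 0.4],
    mus=[[-2.0, 1.5, -1.0], [-1.0, 0.5, -2.0]],
    variances=[[0.1, 0.2, 0.1], [0.3, 0.1, 0.2]],
    rng=rng,
)
print(f"{loss:.4f}")
```

With all variances set to zero the sampling step becomes deterministic and the loss reduces to ordinary BCE on the predicted mean logits.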

My question is: why do it this way? Couldn't we compute the classification loss the same way as the localization loss, i.e., by maximizing the likelihood of the positive and negative samples given the predicted GMM mean and variance, such as N(GT_pos | mu_p, Sigma_p) and N(GT_neg | mu_p, Sigma_p), where mu_p and Sigma_p are computed by the network?
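For comparison, the alternative proposed in the question, scoring the 0/1 label of each class directly with a Gaussian likelihood, might look like the sketch below. It uses a single Gaussian per class (mixture weighting is omitted for brevity), and `gaussian_nll` and `label_likelihood_loss` are hypothetical names; it simply writes out -log N(GT | mu_p, Sigma_p) summed over classes.

```python
import math
from typing import Sequence


def gaussian_nll(y: float, mu: float, var: float) -> float:
    """-log N(y | mu, var) for a scalar target y."""
    return 0.5 * math.log(2.0 * math.pi * var) + (y - mu) ** 2 / (2.0 * var)


def label_likelihood_loss(target: Sequence[int], mus: Sequence[float],
                          variances: Sequence[float]) -> float:
    """Sum over classes of -log N(t_c | mu_c, var_c) for a one-hot target t."""
    return sum(gaussian_nll(t, mu, var)
               for t, mu, var in zip(target, mus, variances))


# Hypothetical predictions for a 3-class problem whose GT class is index 1.
print(label_likelihood_loss([0, 1, 0], mus=[0.1, 0.8, 0.0],
                            variances=[0.05, 0.1, 0.05]))
```

Written this way, the 0/1 labels are treated as continuous regression targets and no sampling step is needed. One consequence is that when a predicted mean matches its 0/1 target exactly, the loss keeps decreasing as the predicted variance shrinks.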

I hope you can help clear up my confusion.

Best regards

jwchoi384 commented 2 years ago

Hello @TianpengBu, as we know, the output of a regression task follows a continuous distribution, whereas the output of a classification task follows a discrete distribution. To apply a Gaussian distribution to classification, we need to sample from that distribution to obtain a discrete value. However, sampling is not differentiable, so gradients cannot be backpropagated through it. That's why we use the reparameterization trick, which makes the sampled value differentiable with respect to the predicted mean and variance. You can easily find related material in papers on (variational) autoencoders.
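The differentiability point can be checked with a small example. The sketch below draws eps ~ N(0, 1) once, forms z = mu + sigma * eps, and scores z as the logit of a positive label with the loss log(1 + exp(-z)). Because eps does not depend on mu or sigma, the chain rule gives dL/dmu = dL/dz and dL/dsigma = dL/dz * eps. In PyTorch, `torch.distributions.Normal(...).rsample()` performs this kind of reparameterized sampling and autograd computes the gradients; here they are written out by hand, with hypothetical function names and an illustrative loss, so the sketch needs only the standard library.

```python
import math
import random
from typing import Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def loss_and_grads(mu: float, sigma: float, eps: float) -> Tuple[float, float, float]:
    """Loss for a positive label on the reparameterized logit z, plus dL/dmu and dL/dsigma."""
    z = mu + sigma * eps           # reparameterized sample; the noise eps is an input
    loss = softplus(-z)            # equals -log(sigmoid(z))
    dloss_dz = sigmoid(z) - 1.0    # derivative of softplus(-z) with respect to z
    return loss, dloss_dz * 1.0, dloss_dz * eps  # dz/dmu = 1, dz/dsigma = eps


rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)          # the only random step, outside the gradient path
loss, g_mu, g_sigma = loss_and_grads(mu=0.5, sigma=0.8, eps=eps)

# Finite-difference check of dL/dmu with eps held fixed.
h = 1e-6
fd_mu = (loss_and_grads(0.5 + h, 0.8, eps)[0] - loss_and_grads(0.5 - h, 0.8, eps)[0]) / (2 * h)
print(g_mu, fd_mu, g_sigma)
```

The analytic gradients agree with finite differences precisely because the random draw is moved into eps; sampling z directly from N(mu, sigma^2) would give no such path from the loss back to mu and sigma.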