cuishuhao / BNM

Code of "Towards Discriminability and Diversity: Batch Nuclear-norm Maximization under Label Insufficient Situations" (CVPR 2020 oral)
MIT License

Is it right to use `torch.mean(s_tgt)` when C < B? #13

Open cantabile-kwok opened 2 years ago

cantabile-kwok commented 2 years ago

Hi, I am studying your approach through your implementation. My question is this: in your paper, Equation 12 computes the BNM loss with the batch size B as the divisor. But in BNM/DA/BNM/train_image.py L#164, this is done with torch.mean(). If the class number C is smaller than the batch size B, the SVD produces an s_tgt of length C instead of B, so torch.mean() divides by C. Wouldn't that be incorrect according to the original equation? Why not divide explicitly by the batch size?
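The discrepancy is easy to demonstrate. A minimal sketch in NumPy (standing in for `torch.svd` on a hypothetical B x C prediction matrix; the sizes are made up for illustration):

```python
import numpy as np

np.random.seed(0)
B, C = 36, 12  # hypothetical batch size and class count, with C < B

# A random softmax-like prediction matrix (rows sum to 1)
G = np.random.rand(B, C)
G /= G.sum(axis=1, keepdims=True)

# SVD of a B x C matrix yields only min(B, C) singular values
s = np.linalg.svd(G, compute_uv=False)

loss_mean = -s.mean()     # what torch.mean(s_tgt) computes: -sum(s) / min(B, C)
loss_eq12 = -s.sum() / B  # Equation 12 in the paper: -sum(s) / B
```

When C < B the two losses differ by the constant factor B / C, so the code's loss is larger in magnitude than the paper's by that factor.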

cuishuhao commented 2 years ago

I admit I was a little careless about the weight. In the equation, the loss is divided by the batch size B; in the code, torch.mean() effectively divides by min(B, C). Since L_bnm is combined with a hyperparameter \lambda, this only rescales the effective value of \lambda. In other experiments I found that dividing by \sqrt(B * C) can work better, and performance can improve further by tuning \lambda.
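In other words, for a fixed B and C all three divisors give losses that differ only by a constant factor, so the choice can be absorbed into \lambda. A small sketch (sizes hypothetical, `lam` is an illustrative weight):

```python
import numpy as np

np.random.seed(0)
B, C = 36, 12  # hypothetical batch size and class count
s = np.linalg.svd(np.random.rand(B, C), compute_uv=False)
nuc = s.sum()  # batch nuclear norm

# Three candidate divisors; for fixed B and C they differ
# only by constant factors, absorbable into the loss weight.
by_B       = nuc / B               # Equation 12
by_min_bc  = nuc / min(B, C)       # torch.mean() in the code
by_sqrt_bc = nuc / np.sqrt(B * C)  # the author's later suggestion

# Rescaling lambda by min(B, C) / B makes the first two losses identical
lam = 0.1
assert np.isclose(lam * by_B, (lam * min(B, C) / B) * by_min_bc)
```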