zhmiao / OpenLongTailRecognition-OLTR

Pytorch implementation for "Large-Scale Long-Tailed Recognition in an Open World" (CVPR 2019 ORAL)
BSD 3-Clause "New" or "Revised" License

About the center loss implementation #10

Closed valencebond closed 5 years ago

valencebond commented 5 years ago

I am not very sure about the center loss implementation. Why do we need to implement the attracting loss using DiscCentroidsLossFunc (torch.autograd.Function)? Can we just implement it the same way as the repelling loss?

zhmiao commented 5 years ago

Hello @valencebond

First of all, by using the attracting loss, the classification confidence can be higher. At the same time, we are also using distance as the scoring metric for open class detection. If we force the samples of each class to be more compact around their corresponding centroids, then during testing, with the help of reachability, samples from open classes are more likely to have larger distances to the centroids than samples from known classes. Does this make sense to you?
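For illustration only, here is a rough sketch of what distance-based open-class scoring could look like. The function name, shapes, and the fixed threshold are hypothetical and simpler than the actual reachability-based scoring in this repo:

```python
import torch

def open_class_score(feat, centroids, threshold):
    """Flag samples that are far from every known-class centroid.

    feat:      (batch, feat_dim) embedded features
    centroids: (num_classes, feat_dim) known-class centroids
    threshold: distance above which a sample is treated as open class
    """
    # Pairwise Euclidean distances to every known-class centroid.
    dists = torch.cdist(feat, centroids)      # (batch, num_classes)
    min_dist, _ = dists.min(dim=1)            # distance to the nearest centroid
    is_open = min_dist > threshold            # True => likely an open-class sample
    return min_dist, is_open
```

The more compact each known class is around its centroid, the cleaner this separation becomes.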

valencebond commented 5 years ago

Hi @zhmiao, thanks for your detailed explanation. Maybe I didn't express my question clearly. What I'm asking is why the attracting loss gradient is computed by hand in the backward function of DiscCentroidsLossFunc. Is the hand-written backward essential? Could we just compute the loss in the forward pass of DiscCentroidsLossFunc and let autograd handle the backward?

zhmiao commented 5 years ago

Hello @valencebond Actually, there was no particular reason for using this specific implementation. It would also be possible to simply write a forward function and let PyTorch handle the backward. We believe both ways should work equally well.
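As a minimal sketch of that alternative (the class name, shapes, and squared-distance form here are illustrative, not the repo's exact DiscCentroidsLoss code), an attracting loss can be written as a plain nn.Module with only a forward pass, and autograd derives the gradients:

```python
import torch
import torch.nn as nn

class SimpleAttractingLoss(nn.Module):
    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # Learnable class centroids, one per known class.
        self.centroids = nn.Parameter(torch.randn(num_classes, feat_dim))

    def forward(self, feat, labels):
        # feat: (batch, feat_dim), labels: (batch,)
        # Pull each sample toward the centroid of its ground-truth class
        # by penalizing the squared Euclidean distance.
        centers = self.centroids[labels]                  # (batch, feat_dim)
        return ((feat - centers) ** 2).sum(dim=1).mean()

# Usage sketch:
# criterion = SimpleAttractingLoss(num_classes=1000, feat_dim=512)
# loss = criterion(features, targets)
# loss.backward()   # no hand-written backward needed
```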

valencebond commented 5 years ago

thanks for your reply~