deepglint / unicom

MLCD & UNICOM : Large-Scale Visual Representation Model
https://huggingface.co/DeepGlint-AI/mlcd-vit-large-patch14-336

Bug in CombinedMarginLoss implementation #11

Closed. MaroonAmor closed this issue 3 weeks ago.

MaroonAmor commented 1 year ago

Hi @anxiangsir,

Thanks for sharing your work.

I have a question about the forward pass in CombinedMarginLoss when running sop_vit_b_16.sh as an example. In this case, self.m1 = 1.0, self.m2 = 0.25, and self.m3 = 0.0. But I think that with torch.no_grad(), the gradients won't be propagated correctly through the margin transform, right?

It also seems that the implementation of CombinedMarginLoss is adapted from the insightface repo, and the previous version there (without torch.no_grad()) makes more sense to me: https://github.com/deepinsight/insightface/commit/657ae30e41fc53641a50a68694009d0530d9f6b3
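For concreteness, here is a simplified sketch of the two variants I mean, for the ArcFace branch (self.m1 == 1.0 and self.m3 == 0.0). This is my paraphrase of the pattern, not the exact code from either repo:

```python
import torch

# Current pattern: the angular margin is applied under torch.no_grad(),
# so autograd never records the arccos/cos chain. In backward, the
# modified logits are treated as if they came straight from the linear layer.
def margin_no_grad(logits: torch.Tensor, labels: torch.Tensor,
                   m2: float = 0.25, s: float = 64.0) -> torch.Tensor:
    index = torch.arange(logits.size(0), device=logits.device)
    with torch.no_grad():
        theta = logits[index, labels].arccos()      # theta = arccos(cos_theta)
        logits[index, labels] = (theta + m2).cos()  # cos(theta + m2)
    return logits * s

# Previous pattern (pre-657ae30e): same forward math, but autograd records
# the transform, so the gradient of each target logit carries the extra
# factor d cos(theta + m2) / d cos(theta) = sin(theta + m2) / sin(theta).
def margin_with_grad(logits: torch.Tensor, labels: torch.Tensor,
                     m2: float = 0.25, s: float = 64.0) -> torch.Tensor:
    index = torch.arange(logits.size(0), device=logits.device)
    theta = logits[index, labels].arccos()
    logits[index, labels] = (theta + m2).cos()
    return logits * s
```

(Both assume logits is the non-leaf output of the linear layer, so the in-place edits are legal for autograd.)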

Related issues raising the same question: https://github.com/deepinsight/insightface/issues/2218, https://github.com/deepinsight/insightface/issues/2255, https://github.com/deepinsight/insightface/issues/2309

Why do we need torch.no_grad() here?

anxiangsir commented 1 year ago

Here we mainly adopted the implementation from OpenSphere, and we found that it makes ArcFace more stable when training ViT.

MaroonAmor commented 1 year ago

@anxiangsir Thanks for getting back to me.

But it is not technically correct, right? The gradients won't be propagated back through those lines under torch.no_grad() (e.g., logits.arccos_()).
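To illustrate, here is a toy check I put together (my own snippet with made-up numbers, not code from the repo; m2 = 0.25 and s = 64.0 as above). The gradient of the target logit differs between the two settings:

```python
import contextlib
import torch
import torch.nn.functional as F

def target_logit_grad(use_no_grad: bool, m2: float = 0.25, s: float = 64.0):
    # Toy cosine logits: 2 samples x 3 classes. The clone stands in for
    # the output of the linear layer, so the in-place edit is legal.
    cos_theta = torch.tensor([[0.7, 0.1, -0.2],
                              [0.3, 0.6,  0.0]], requires_grad=True)
    labels = torch.tensor([0, 1])
    index = torch.arange(2)
    logits = cos_theta.clone()
    ctx = torch.no_grad() if use_no_grad else contextlib.nullcontext()
    with ctx:
        logits[index, labels] = (logits[index, labels].arccos() + m2).cos()
    F.cross_entropy(logits * s, labels).backward()
    return cos_theta.grad[index, labels]

# With no_grad, backward skips d cos(theta + m2) / d cos(theta); without
# it, the gradient picks up the extra sin(theta + m2) / sin(theta) factor.
print(target_logit_grad(True))
print(target_logit_grad(False))
```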

Also, I ran a comparison experiment (w/ torch.no_grad() vs. w/o torch.no_grad()) on the SOP dataset using an A100 GPU. The performance w/o torch.no_grad() was actually better.

Is there any theory or math to support the change to add torch.no_grad()? This has really confused me for a while. Thanks.