jonasgrebe / pt-femb-face-embeddings

Implementation of related angular-margin-based classification loss functions for training (face) embedding models: SphereFace, CosFace, ArcFace and MagFace.
MIT License

Found two flaws in ArcMarginHeader #3

Closed alexkrz closed 4 months ago

alexkrz commented 5 months ago

Hi,

I came across your repository while comparing different approaches for computing face embeddings. I like your approach of using the ArcMarginHeader class as the base class for all the other headers. However, when training a model with the ArcFaceHeader, I noticed two flaws in the ArcMarginHeader class.

Both flaws appear in the forward() function:

  1. line 28:

    self.linear.weight = torch.nn.Parameter(self.normalize(self.linear.weight))

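    You need to create the weight parameter once, when the header class is initialized, instead of re-creating it as a new torch.nn.Parameter in every forward pass. An optimizer constructed before the first forward pass keeps a reference to the original parameter, so it never updates the re-created one.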

  2. line 32:

    theta = torch.acos(logits).clamp(-1, 1)

    The output of arccos must not be clamped to [-1, 1]: arccos returns values between 0 and pi, so this clamp caps every angle above 1 radian at 1.
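For illustration, here is a minimal sketch of how both points could be addressed. It is not the repository's code: the class name `ArcMarginSketch`, its constructor arguments, the default values for `s` and `m`, and the small `eps` guard are all assumptions made for this example. The sketch applies only a plain additive angular margin and skips the special handling some implementations use when theta + m exceeds pi.

```python
import torch
import torch.nn.functional as F


class ArcMarginSketch(torch.nn.Module):
    """Hypothetical stand-in for the header, not the repository's class."""

    def __init__(self, in_features, out_features, s=64.0, m=0.5, eps=1e-7):
        super().__init__()
        # Point 1: create the weight once, so the optimizer keeps
        # tracking the same Parameter object for the whole training run.
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        torch.nn.init.xavier_uniform_(self.weight)
        self.s = s      # assumed scale factor
        self.m = m      # assumed additive angular margin
        self.eps = eps  # assumed guard that keeps the acos gradient finite

    def forward(self, embeddings, labels):
        # Normalize functionally instead of re-assigning the parameter.
        cos_theta = F.linear(
            F.normalize(embeddings, dim=1),
            F.normalize(self.weight, dim=1),
        )
        # Point 2: clamp the input of acos to its domain, not the output.
        theta = torch.acos(cos_theta.clamp(-1 + self.eps, 1 - self.eps))
        # Add the margin only to the angle of the target class.
        one_hot = F.one_hot(labels, num_classes=self.weight.shape[0])
        return self.s * torch.cos(theta + self.m * one_hot.to(theta.dtype))
```

With this layout, `head.weight` refers to the same object before and after a forward pass, so an optimizer created from `head.parameters()` keeps updating the weight that is actually used.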

I will create a PR that fixes both issues.