Joker316701882 / Additive-Margin-Softmax

This is the implementation of the paper "Additive Margin Softmax for Face Verification"

Hi, how can I use the function AM_logits_compute? #1

Closed yuan-wenhua closed 6 years ago

yuan-wenhua commented 6 years ago

I tried to use your resface.py as the inference network while using Facenet as the training logic. But I am not sure how to use the function AM_logits_compute; neither using it to replace prelogits_center_loss nor using it to replace the 'logits' parameter of tf.nn.sparse_softmax_cross_entropy_with_logits worked. Can you tell me how to use AM_logits_compute? Thanks!

Joker316701882 commented 6 years ago

@yuan-wenhua Thank you for your interest! It basically works like this:

prelogits,_ = network.inference(...)
embeddings = tf.nn.l2_normalize(prelogits,...)
AM_logits = AM_logits_compute(embeddings, label_batch, args, nrof_classes)
...
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_batch,logits=AM_logits,name='...')

The AM_logits_compute function receives 'embeddings' as input and outputs logits that can be passed directly to some_loss_function_with_logits().
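For intuition, the math inside AM_logits_compute can be sketched in NumPy: additive-margin softmax subtracts a margin m from the target class's cosine similarity and scales everything by s. This is a sketch of the formula from the AM-softmax paper (with its suggested defaults m=0.35, s=30), not the repo's exact TensorFlow code, and the function/variable names here are made up for illustration:

```python
import numpy as np

def am_logits_numpy(embeddings, weights, labels, m=0.35, s=30.0):
    """Additive-margin softmax logits: s * (cos_theta - m) for the true
    class, s * cos_theta for all other classes.
    embeddings: (B, D), each row L2-normalized.
    weights:    (D, C), each column L2-normalized (class weight vectors).
    labels:     (B,) integer class ids."""
    cos_theta = embeddings @ weights               # (B, C) cosine similarities
    margin = np.zeros_like(cos_theta)
    margin[np.arange(len(labels)), labels] = m     # margin only on true class
    return s * (cos_theta - margin)

# Toy check: one unit-norm embedding, two unit-norm class weights.
emb = np.array([[1.0, 0.0]])
w = np.array([[1.0, 0.0],
              [0.0, 1.0]])
logits = am_logits_numpy(emb, w, labels=np.array([0]))
# true class:  30 * (1.0 - 0.35) = 19.5
# other class: 30 * (0.0 - 0)    = 0.0
```

In the actual repo the class weight matrix lives inside AM_logits_compute and the embeddings come from tf.nn.l2_normalize as in the snippet above; the resulting logits then go straight into tf.nn.sparse_softmax_cross_entropy_with_logits.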

yuan-wenhua commented 6 years ago

Thank you for your reply! I tried the code above and got a weird result: the loss decreased to about 18, then plateaued. The LFW result is poor. Do you have any idea what the problem is? I also tried changing s from 30 to 64, and the result is similar. Thanks!

Joker316701882 commented 6 years ago

@yuan-wenhua There is a weird part in Sandberg's facenet. Look into the details of the facenet.train() function: it forces all parameters in the model to use their moving-average values. In that case it seems batch norm is not used correctly (look at the batch_norm_params in his code; is_training is not set correctly). I'm not sure whether it works correctly.

So I modified his code. Basically, replace facenet.train() with a simple optimizer such as tf.train.AdamOptimizer(), and set is_training correctly in resface.py.
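That replacement can be sketched in TF 1.x style roughly as below. This is a generic wiring sketch, not the repo's actual training script: total_loss here is a dummy stand-in for the real AM-softmax cross-entropy, and the learning rate is an arbitrary placeholder value:

```python
import tensorflow as tf  # TF 1.x API

# is_training must be fed at run time so batch norm uses batch statistics
# during training and moving averages during evaluation.
is_training = tf.placeholder(tf.bool, name='is_training')

# Stand-in for the real graph; in practice this would be the cross entropy
# built from AM_logits (e.g. prelogits, _ = network.inference(...)).
dummy_var = tf.Variable(1.0)
total_loss = tf.square(dummy_var)

# Plain optimizer instead of facenet.train(), which additionally forces
# moving-average values onto all model variables.
optimizer = tf.train.AdamOptimizer(learning_rate=0.001)

# Batch-norm moving statistics only update when the update ops run.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(total_loss)
```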

The other training details are the same as in AM-softmax, except that I used Adam, batch norm, and no weight decay (the best configuration I found). Also pay attention to the image size and the alignment method; they have a great effect on the final result.

BTW, my experiments are based on vggface2. Using webface will slightly reduce the accuracy.

yuan-wenhua commented 6 years ago

Thank you very much for your detailed explanation! After I replaced facenet.train() with tf.train.AdamOptimizer(), the loss is now continuously decreasing.

There is a small issue: the second link in your last comment, "AM-softmax", can't be opened; it shows "https://camo.githubusercontent.com/07de533f3c6e15c4eb0f9e91d75dea7a8a7a3974/68747470733a2f2f61727869762e6f72672f6162732f313830312e3035353939".

Thanks!

Joker316701882 commented 6 years ago

@yuan-wenhua Fixed! Thank you : )

yuan-wenhua commented 6 years ago

Hi, coming again~ Thank you for your help with the last issue. Now I can train and the loss is continuously decreasing, but the accuracy does not go higher than 0.73, even when the loss is lower than 2. I am not sure what the problem is. My dataset is ms-celeb-1M, and the inference net uses your code, resface.py. Do you have any idea about my problem? I see you will share your training code soon; I am looking forward to it. Thank you!

Joker316701882 commented 6 years ago

@yuan-wenhua Shared! Open a new issue if you run into new questions.

caocuong0306 commented 6 years ago

Hi @Joker316701882 ,

Regarding the face detection/alignment part, I have two questions. I hope to discuss them with you.

  1. Center crop: When MTCNN cannot detect a face, you use a center crop to locate the face region. Is this a suitable way to increase the amount of training data? My concern is that the cropped regions may not be suitable for alignment. In fact, images aligned from cropped regions still contain a lot of background, and the result is not very different from cropping alone (without alignment).
  2. Alignment or no alignment: Unlike Sandberg's approach, you perform alignment instead of only cropping. While Sandberg argues that keeping the cropped face region with background helps the model generalize better (https://github.com/davidsandberg/facenet/issues/93), you perform alignment. In your experience, which one performs better? Since you indicated that the quality of alignment might be the bottleneck for modern face recognition, I believe you have already tested both.
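For reference, the center-crop fallback being discussed can be sketched like this. This is a generic sketch of the idea (crop the middle of the image when the detector finds nothing), not the repo's exact preprocessing; the 182-pixel output size is just an illustrative value:

```python
import numpy as np

def center_crop(image, crop_size):
    """Fallback when no face is detected: take a central square crop,
    on the assumption that the face is roughly centered in the image.
    image: (H, W, C) array; crop_size: side length of the square crop."""
    h, w = image.shape[:2]
    top = max((h - crop_size) // 2, 0)
    left = max((w - crop_size) // 2, 0)
    return image[top:top + crop_size, left:left + crop_size]

img = np.zeros((250, 250, 3), dtype=np.uint8)  # dummy 250x250 RGB image
crop = center_crop(img, 182)                   # -> shape (182, 182, 3)
```

Because such a crop keeps a fixed central window regardless of where the face actually is, it tends to retain background and is never landmark-aligned, which is exactly the concern raised in point 1 above.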

Looking forward to hearing from you. Thank you.

Joker316701882 commented 5 years ago

@caocuong0306 Sorry for late reply.

  1. Center crop: I think the better choice is this: when training on your own dataset, it's better to discard the images in which no face is detected, because it's highly likely that some images in your own dataset contain no face at all. But when you want to compete on a public benchmark such as MegaFace, the suitable approach is to log the images in which no face is detected, because that usually indicates a failure of the face detection algorithm, and then manually crop and align them.

  2. Alignment or not: This is an interesting question, and I can't give an absolutely right answer. But in my experience, alignment is always better. Notice that Sandberg's code actually applies many extra tricks, which makes it not a typical face recognition codebase, so the conclusions he draws may differ from ours.