leondgarse / Keras_insightface

Insightface Keras implementation
MIT License

VPL - Variational Prototype Learning for deep face recognition #88

Open abdikaiym01 opened 2 years ago

abdikaiym01 commented 2 years ago

Do you have a plan to implement their (insightface) latest work? It seems it does not work as claimed in their paper: https://github.com/deepinsight/insightface/issues/1801

abdikaiym01 commented 2 years ago

Is it the current SOTA for the face recognition task?

leondgarse commented 2 years ago

I think an implementation is possible, will try it and check if it works better.

leondgarse commented 2 years ago

VPL mode is added. It can be enabled by tt = train.Train(..., use_vpl=True). I think it should be the same as the official implementation. Here are my test results using a basic EfficientNetV2B0 + AdamW:

import os
import losses, train, models
import tensorflow_addons as tfa
from tensorflow import keras

keras.mixed_precision.set_global_policy("mixed_float16")

data_basic_path = '/datasets/ms1m-retinaface-t1'
data_path = data_basic_path + '_112x112_folders'
eval_paths = [os.path.join(data_basic_path, ii) for ii in ['lfw.bin', 'cfp_fp.bin', 'agedb_30.bin']]

from keras_cv_attention_models import efficientnet
basic_model = efficientnet.EfficientNetV2B0(input_shape=(112, 112, 3), num_classes=0)
basic_model = models.buildin_models(basic_model, dropout=0, emb_shape=512, output_layer='GDC', bn_epsilon=1e-4, bn_momentum=0.9, scale=True, use_bias=False)

tt = train.Train(data_path, eval_paths=eval_paths,
    save_path='TT_efv2_b0_swish_GDC_arc_emb512_dr0_adamw_5e4_bs512_ms1mv3_randaug_cos16_batch_float16_vpl.h5',
    basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
    batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1, use_vpl=True)

optimizer = tfa.optimizers.AdamW(learning_rate=1e-2, weight_decay=5e-4, exclude_from_weight_decay=["/gamma", "/beta"])

sch = [
    {"loss": losses.ArcfaceLoss(scale=16), "epoch": 4, "optimizer": optimizer},
    {"loss": losses.ArcfaceLoss(scale=32), "epoch": 3},
    {"loss": losses.ArcfaceLoss(scale=64), "epoch": 46},
]
tt.train(sch, 0)
exit()
(Training plot Selection_458 omitted.)

Results

| VPL   | lfw      | cfp_fp   | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
| ----- | -------- | -------- | -------- | --------- | --------- |
| False | 0.997667 | 0.979429 | 0.978333 | 0.941188  | 0.955719  |
| True  | 0.997667 | 0.979571 | 0.978500 | 0.938559  | 0.955054  |
abdikaiym01 commented 2 years ago

In the VPL paper the results were quite different; it seems reality is another story. What do you think about it? Maybe your implementation is a little different from theirs, though I'm not sure about it.

leondgarse commented 2 years ago

Ya, I have compared them several times. It seems the main parts are:

leondgarse commented 2 years ago

Here are the results using start_iters=8000 / 4, allowed_delta=200 / 4 for batch_size=512 (training plot omitted):

Results

| VPL                   | lfw      | cfp_fp   | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
| --------------------- | -------- | -------- | -------- | --------- | --------- |
| False                 | 0.997667 | 0.979429 | 0.978333 | 0.941188  | 0.955719  |
| start 8000, delta 200 | 0.997667 | 0.979571 | 0.978500 | 0.938559  | 0.955054  |
| start 2000, delta 50  | 0.998000 | 0.976429 | 0.977667 | 0.940117  | 0.956128  |

IJBB / IJBC detail

| VPL                         | 1e-06    | 1e-05    | 0.0001   | 0.001    | 0.01     | 0.1      | AUC      |
| --------------------------- | -------- | -------- | -------- | -------- | -------- | -------- | -------- |
| False, IJBB                 | 0.338948 | 0.875365 | 0.941188 | 0.960467 | 0.974684 | 0.983642 | 0.991774 |
| start 8000, delta 200, IJBB | 0.376241 | 0.8815   | 0.938559 | 0.962902 | 0.976339 | 0.985881 | 0.992184 |
| start 2000, delta 50, IJBB  | 0.353944 | 0.874002 | 0.940117 | 0.961538 | 0.974684 | 0.983934 | 0.991567 |
| False, IJBC                 | 0.848954 | 0.927954 | 0.955719 | 0.972184 | 0.982462 | 0.989109 | 0.994352 |
| start 8000, delta 200, IJBC | 0.877895 | 0.928568 | 0.955054 | 0.973513 | 0.983689 | 0.990387 | 0.994527 |
| start 2000, delta 50, IJBC  | 0.867004 | 0.926778 | 0.956128 | 0.972797 | 0.982257 | 0.989211 | 0.994179 |

~~This is the default adjustment now: self.start_iters, self.allowed_delta = 8000 * 128 // batch_size, 200 * 128 // batch_size~~. I think it does make some sense, especially for the IJBB / IJBC 1e-6 accuracy. Also notice start 8000, delta 200 is actually higher at every TAR@FAR except 1e-4... It may be worth a try in some situations.
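For reference, a minimal sketch of that batch-size scaling (the function name and keyword defaults are illustrative, not from the repo; 128 is the reference batch size the adjustment assumes):

def scale_vpl_params(batch_size, base_start_iters=8000, base_allowed_delta=200, base_batch_size=128):
    # Scale the VPL start iterations and allowed delta from the reference batch size to the actual one.
    start_iters = base_start_iters * base_batch_size // batch_size
    allowed_delta = base_allowed_delta * base_batch_size // batch_size
    return start_iters, allowed_delta

print(scale_vpl_params(512))  # (2000, 50), matching the "start 2000, delta 50" rows above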

leondgarse commented 2 years ago

It's 2 parameters now, vpl_start_iters and vpl_allowed_delta; use_vpl is abandoned. VPL mode is enabled by setting vpl_start_iters > 0, like tt = train.Train(..., vpl_start_iters=8000). Defaults are vpl_start_iters=-1, vpl_allowed_delta=200.
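For example, a minimal sketch of enabling it (reusing data_path, eval_paths and basic_model from the training script above; the save_path here is just a placeholder name):

tt = train.Train(data_path, eval_paths=eval_paths, basic_model=basic_model,
    save_path='TT_efv2_b0_vpl_test.h5',  # placeholder file name
    lr_base=0.01, batch_size=512,
    vpl_start_iters=8000, vpl_allowed_delta=200)  # vpl_start_iters > 0 enables VPL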

abdikaiym01 commented 2 years ago

Thank you for your work. It's indeed worth a try. An additional question about the IJB validation dataset: did you try their 1:N test?

leondgarse commented 2 years ago

I'm using my IJB_evals.py. I just ran a bunch of 1:N tests; VPL with start 8000, delta 200 performs not badly in this test:
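(As an aside before the numbers: a 1:N test matches each probe embedding against a gallery and reports identification accuracy, unlike the 1:1 verification scores above. A rough numpy sketch of the rank-1 metric, not the IJB_evals.py implementation; probe_embs, probe_labels, gallery_embs, gallery_labels are hypothetical arrays:)

import numpy as np

def rank1_identification(probe_embs, probe_labels, gallery_embs, gallery_labels):
    # L2-normalize embeddings so the dot product equals cosine similarity.
    probe_embs = probe_embs / np.linalg.norm(probe_embs, axis=1, keepdims=True)
    gallery_embs = gallery_embs / np.linalg.norm(gallery_embs, axis=1, keepdims=True)
    similarity = probe_embs @ gallery_embs.T  # shape [num_probes, num_gallery]
    best_match = similarity.argmax(axis=1)    # closest gallery entry per probe
    return np.mean(gallery_labels[best_match] == probe_labels)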