abdikaiym01 opened 2 years ago
Is VPL (Variational Prototype Learning) the current SOTA for face recognition?
I think an implementation is possible. I will try it and check whether it performs better.
VPL mode has been added. It can be enabled with `tt = train.Train(..., use_vpl=True)`. I think it should match the official implementation. Here are my test results using a basic EfficientNetV2B0 + AdamW:
```python
import os
from tensorflow import keras
import tensorflow_addons as tfa
import losses, train, models  # modules from this repo
from keras_cv_attention_models import efficientnet

keras.mixed_precision.set_global_policy("mixed_float16")

data_basic_path = '/datasets/ms1m-retinaface-t1'
data_path = data_basic_path + '_112x112_folders'
eval_paths = [os.path.join(data_basic_path, ii) for ii in ['lfw.bin', 'cfp_fp.bin', 'agedb_30.bin']]

# EfficientNetV2B0 backbone without a classifier head, wrapped with a 512-d GDC embedding output
basic_model = efficientnet.EfficientNetV2B0(input_shape=(112, 112, 3), num_classes=0)
basic_model = models.buildin_models(basic_model, dropout=0, emb_shape=512, output_layer='GDC', bn_epsilon=1e-4, bn_momentum=0.9, scale=True, use_bias=False)

tt = train.Train(data_path, eval_paths=eval_paths,
    save_path='TT_efv2_b0_swish_GDC_arc_emb512_dr0_adamw_5e4_bs512_ms1mv3_randaug_cos16_batch_float16_vpl.h5',
    basic_model=basic_model, model=None, lr_base=0.01, lr_decay=0.5, lr_decay_steps=16, lr_min=1e-6, lr_warmup_steps=3,
    batch_size=512, random_status=100, eval_freq=4000, output_weight_decay=1, use_vpl=True)

# AdamW, excluding normalization gamma / beta variables from weight decay
optimizer = tfa.optimizers.AdamW(learning_rate=1e-2, weight_decay=5e-4, exclude_from_weight_decay=["/gamma", "/beta"])
# ArcFace loss, with the scale ramped up 16 -> 32 -> 64 across stages
sch = [
    {"loss": losses.ArcfaceLoss(scale=16), "epoch": 4, "optimizer": optimizer},
    {"loss": losses.ArcfaceLoss(scale=32), "epoch": 3},
    {"loss": losses.ArcfaceLoss(scale=64), "epoch": 46},
]
tt.train(sch, 0)
```
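The schedule above only changes the `scale` of `losses.ArcfaceLoss` between stages. As a rough illustration of what that parameter does, here is a stdlib sketch of the standard ArcFace logit: `s * cos(theta + m)` for the target class and `s * cos(theta)` otherwise. `arcface_logit` is a hypothetical helper, not this repo's API, and the margin `m = 0.5` is an assumed common default rather than a value read from `ArcfaceLoss`. It also omits the `theta + m > pi` fallback that practical implementations add.

```python
import math


def arcface_logit(cos_theta: float, is_target: bool, scale: float, margin: float = 0.5) -> float:
    """Standard ArcFace logit for one class (margin=0.5 is an assumed default).

    cos_theta: cosine between the normalized embedding and the normalized class weight.
    """
    # Clamp into acos's domain to guard against float rounding.
    cos_theta = max(-1.0, min(1.0, cos_theta))
    if is_target:
        # Additive angular margin, applied to the ground-truth class only.
        return scale * math.cos(math.acos(cos_theta) + margin)
    return scale * cos_theta


# Same cosines, larger scale: the logits spread further apart, sharpening the softmax.
for s in (16, 32, 64):
    print(s, round(arcface_logit(0.8, True, s), 3), round(arcface_logit(0.3, False, s), 3))
```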
Results:

VPL | lfw | cfp_fp | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
---|---|---|---|---|---|
False | 0.997667 | 0.979429 | 0.978333 | 0.941188 | 0.955719 |
True | 0.997667 | 0.979571 | 0.978500 | 0.938559 | 0.955054 |
The results in the VPL paper were very different; in practice it turns out to be quite another story. What do you think about it? Maybe your implementation is a little different from theirs, though I'm not sure.
Yes, I have compared them several times. The main parts seem to be:

```python
# Official VPL (PyTorch), excerpt: per-class lambda, nonzero only for recently updated queue entries
def prepare_queue_lambda(self, label, iters):
    self.queue_lambda[:] = 0.0
    if iters > self.cfg['start_iters']:
        allowed_delta = self.cfg['allowed_delta']
        if self.vpl_mode == 0:
            past_iters = iters - self.queue_iters
            idx = torch.where(past_iters <= allowed_delta)[0]
            self.queue_lambda[idx] = self.cfg['lambda']
```
In my implementation, it's in models.py#L247:

```python
queue_lambda = tf.cond(
    self.iters > self.start_iters,
    lambda: tf.where(self.iters - self.queue_iters <= self.allowed_delta, self.vpl_lambda, 0.0),  # prepare_queue_lambda
    lambda: self.zero_queue_lambda,
)
```
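Both versions compute the same per-class mask: once training passes `start_iters`, a class gets weight `lambda` only if its queued feature was refreshed within the last `allowed_delta` iterations. A framework-free sketch of that rule; `queue_lambda` here is a hypothetical helper and `vpl_lambda=0.15` is only an example value.

```python
from typing import List


def queue_lambda(iters: int, queue_iters: List[int], start_iters: int,
                 allowed_delta: int, vpl_lambda: float) -> List[float]:
    """Per-class prototype mixing weight, following the vpl_mode == 0 rule above."""
    # Before start_iters, VPL is off: every class keeps its plain weight.
    if iters <= start_iters:
        return [0.0] * len(queue_iters)
    # Afterwards, only classes whose queued feature is recent enough are mixed in.
    return [vpl_lambda if iters - last <= allowed_delta else 0.0 for last in queue_iters]


print(queue_lambda(9000, [8900, 8700, 0], start_iters=8000, allowed_delta=200, vpl_lambda=0.15))
# [0.15, 0.0, 0.0]
```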
```python
# ...
_lambda = self.queue_lambda.view(self.num_local, 1)
injected_weight = norm_weight * (1.0 - _lambda) + self.queue * _lambda
injected_norm_weight = normalize(injected_weight)
# ...
```
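And here, models.py#L254:

```python
# self.w is [emb, classes], so it is normalized per class along axis 0;
# queue_features is [classes, emb] and is transposed to match
norm_w = K.l2_normalize(self.w, axis=0)
injected_weight = norm_w * (1 - queue_lambda) + tf.transpose(self.queue_features) * queue_lambda
injected_norm_weight = K.l2_normalize(injected_weight, axis=0)
```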
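For a single class, both snippets reduce to a linear blend of the normalized class weight and its queued feature, followed by re-normalization. A stdlib sketch of that per-class step (`inject` and `l2_normalize` are hypothetical helpers; the small epsilon guard is an assumption mirroring typical `l2_normalize` implementations):

```python
import math
from typing import List


def l2_normalize(v: List[float], eps: float = 1e-12) -> List[float]:
    # Divide by the L2 norm, guarding against an all-zero vector.
    norm = math.sqrt(max(sum(x * x for x in v), eps))
    return [x / norm for x in v]


def inject(norm_w: List[float], queued: List[float], lam: float) -> List[float]:
    """Blend one class's normalized weight with its queued feature, then renormalize."""
    mixed = [w * (1.0 - lam) + q * lam for w, q in zip(norm_w, queued)]
    return l2_normalize(mixed)


print(inject([1.0, 0.0], [0.0, 1.0], 0.0))  # lambda = 0: the weight is unchanged
print(inject([1.0, 0.0], [0.0, 1.0], 0.5))  # halfway between the two directions
```

With `lambda = 0` (before `start_iters`, or for stale classes) this is exactly plain normalized-weight ArcFace, which is why VPL only changes training for recently seen classes.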
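```python
# Official VPL (PyTorch), excerpt: store each positive sample's normalized feature as its class's queued feature
def set_queue(self, total_features, total_label, index_positive, iters):
    local_label = total_label[index_positive]
    sel_features = normalize(total_features[index_positive, :])
    self.queue[local_label, :] = sel_features
    self.queue_iters[local_label] = iters
```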
And here, in myCallbacks/VPLUpdateQueue:

```python
import tensorflow as tf
from tensorflow import keras


class VPLUpdateQueue(keras.callbacks.Callback):
    def __init__(self):
        super().__init__()

    def on_batch_end(self, batch, logs=None):
        # True labels of the current batch, kept by the loss object
        batch_labels_back_up = self.model.loss[0].batch_labels_back_up
        update_label_pos = tf.expand_dims(batch_labels_back_up, 1)
        # The last layer is the VPL NormDense layer holding the queue
        vpl_norm_dense_layer = self.model.layers[-1]
        # Overwrite the queued feature of every class seen in this batch
        updated_queue = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_features, update_label_pos, vpl_norm_dense_layer.norm_features)
        vpl_norm_dense_layer.queue_features.assign(updated_queue)
        # Record the iteration at which each of those classes was updated
        iters = tf.repeat(vpl_norm_dense_layer.iters, tf.shape(batch_labels_back_up)[0])
        updated_queue_iters = tf.tensor_scatter_nd_update(vpl_norm_dense_layer.queue_iters, update_label_pos, iters)
        vpl_norm_dense_layer.queue_iters.assign(updated_queue_iters)
```
Since it needs the true labels, this update has to be done outside the model. I think the logic matches the official one, unless I'm missing something...
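The queue update itself is just a scatter indexed by the true labels: each class seen in the batch gets its normalized feature stored and its iteration stamped, which is why the labels must be available. A stdlib sketch with plain lists (`set_queue_sketch` is a hypothetical helper, not the official or repo function). If a label repeats within a batch, this loop keeps the last sample.

```python
import math
from typing import List


def l2_normalize(v: List[float], eps: float = 1e-12) -> List[float]:
    norm = math.sqrt(max(sum(x * x for x in v), eps))
    return [x / norm for x in v]


def set_queue_sketch(queue: List[List[float]], queue_iters: List[int], labels: List[int],
                     features: List[List[float]], iters: int) -> None:
    """In place: store each sample's normalized feature as its class's queued feature."""
    for label, feature in zip(labels, features):
        queue[label] = l2_normalize(feature)
        queue_iters[label] = iters


num_classes, emb = 4, 2
queue = [[0.0] * emb for _ in range(num_classes)]
queue_iters = [0] * num_classes
set_queue_sketch(queue, queue_iters, labels=[2, 0], features=[[3.0, 4.0], [0.0, 2.0]], iters=8100)
print(queue, queue_iters)
```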
`'start_iters': 8000, 'allowed_delta': 200` is for `batch_size = 128`, but I'm using those values with `batch_size=512`. I may try `start_iters=8000 / 4, allowed_delta=200 / 4` for `batch_size=512` later.

Here is the result using `start_iters=8000 / 4, allowed_delta=200 / 4` for `batch_size=512`:
Results:

VPL | lfw | cfp_fp | agedb_30 | IJBB 1e-4 | IJBC 1e-4 |
---|---|---|---|---|---|
False | 0.997667 | 0.979429 | 0.978333 | 0.941188 | 0.955719 |
start 8000, delta 200 | 0.997667 | 0.979571 | 0.978500 | 0.938559 | 0.955054 |
start 2000, delta 50 | 0.998000 | 0.976429 | 0.977667 | 0.940117 | 0.956128 |
IJBB / IJBC detail (TAR @ FAR):

VPL, dataset | 1e-06 | 1e-05 | 0.0001 | 0.001 | 0.01 | 0.1 | AUC |
---|---|---|---|---|---|---|---|
False, IJBB | 0.338948 | 0.875365 | 0.941188 | 0.960467 | 0.974684 | 0.983642 | 0.991774 |
start 8000, delta 200, IJBB | 0.376241 | 0.8815 | 0.938559 | 0.962902 | 0.976339 | 0.985881 | 0.992184 |
start 2000, delta 50, IJBB | 0.353944 | 0.874002 | 0.940117 | 0.961538 | 0.974684 | 0.983934 | 0.991567 |
False, IJBC | 0.848954 | 0.927954 | 0.955719 | 0.972184 | 0.982462 | 0.989109 | 0.994352 |
start 8000, delta 200, IJBC | 0.877895 | 0.928568 | 0.955054 | 0.973513 | 0.983689 | 0.990387 | 0.994527 |
start 2000, delta 50, IJBC | 0.867004 | 0.926778 | 0.956128 | 0.972797 | 0.982257 | 0.989211 | 0.994179 |
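Each numeric column above (except AUC) is a TAR at a fixed FAR. A simplified stdlib sketch of that metric from raw genuine / impostor similarity scores: choose the threshold so that at most a `far` fraction of impostor pairs score above it, then report the fraction of genuine pairs above it. `tar_at_far` is a hypothetical helper for illustration, not the code that produced these numbers.

```python
from typing import List


def tar_at_far(genuine: List[float], impostor: List[float], far: float) -> float:
    """True accept rate at the threshold where at most `far` of impostors are accepted."""
    imp = sorted(impostor, reverse=True)
    # Number of impostor pairs allowed above the threshold (floored).
    n_fa = int(far * len(imp))
    # Accept only scores strictly above the (n_fa)-th highest impostor score.
    threshold = imp[n_fa] if n_fa < len(imp) else float("-inf")
    return sum(1 for g in genuine if g > threshold) / len(genuine)


impostor = [i / 10 for i in range(1, 11)]
genuine = [0.95, 0.85, 0.99, 0.5]
for far in (0.0, 0.1, 0.5, 1.0):
    print(far, tar_at_far(genuine, impostor, far))
```

A lower FAR forces a higher threshold, so TAR can only go down as FAR shrinks, which matches the left-to-right trend in each row.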
~~This is the default adjustment now: `self.start_iters, self.allowed_delta = 8000 * 128 // batch_size, 200 * 128 // batch_size`~~. I think VPL does make some sense, especially for IJBB / IJBC accuracy at 1e-6. Also notice that `start 8000, delta 200` is actually higher at every TAR@FAR except 1e-4... It may be worth a try in some situations.
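The struck-out adjustment scales both values by `128 / batch_size`, so the warm-up and the freshness window cover the same number of training samples as at batch size 128 (for example 8000 × 128 = 1,024,000 = 2000 × 512). A small sketch of that arithmetic; `scale_vpl_iters` is a hypothetical helper name:

```python
from typing import Tuple


def scale_vpl_iters(batch_size: int, start_iters: int = 8000, allowed_delta: int = 200,
                    base_batch_size: int = 128) -> Tuple[int, int]:
    # Keep iterations x batch size (samples seen) constant across batch sizes.
    return start_iters * base_batch_size // batch_size, allowed_delta * base_batch_size // batch_size


print(scale_vpl_iters(512))  # (2000, 50)
```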
There are two parameters now, `vpl_start_iters` and `vpl_allowed_delta`; `use_vpl` is abandoned. VPL mode is now enabled by setting `vpl_start_iters > 0`, like `tt = train.Train(..., vpl_start_iters=8000)`. Defaults are `vpl_start_iters=-1, vpl_allowed_delta=200`.
Thank you for your work. It's indeed worth a try. An additional question about the IJB evaluation dataset: did you try their 1:N test?
I'm using my `IJB_evals.py`. I just ran a bunch of 1:N tests, and VPL with `start 8000, delta 200` performs reasonably well in this test:
```
>>>> Gallery 1 top1: 0.972289, top5: 0.979864, top10: 0.980662
>>>> Gallery 2 top1: 0.933955, top5: 0.958698, top10: 0.966882
>>>> Mean [Mean] top1: 0.952678, top5: 0.969036, top10: 0.973612
```
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.188596 | 0.849763 | 0.032737 | 0.923226 | 0.110667 |
0.001 | 0.361443 | 0.796886 | 0.169775 | 0.843299 | 0.265609 |
0.01 | 0.919458 | 0.406492 | 0.881043 | 0.37894 | 0.90025 |
0.1 | 0.962919 | 0.266365 | 0.924629 | 0.264748 | 0.943774 |
1 | 0.972289 | 0.12289 | 0.933955 | 0.118918 | 0.953122 |
```
>>>> Gallery 1 top1: 0.972289, top5: 0.981459, top10: 0.982855
>>>> Gallery 2 top1: 0.937571, top5: 0.962314, top10: 0.967834
>>>> Mean [Mean] top1: 0.954528, top5: 0.971665, top10: 0.975170
```
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.226077 | 0.840018 | 0.0411115 | 0.917713 | 0.133594 |
0.001 | 0.394338 | 0.788333 | 0.252569 | 0.812294 | 0.323454 |
0.01 | 0.916069 | 0.407797 | 0.894366 | 0.359157 | 0.905217 |
0.1 | 0.960526 | 0.270085 | 0.926532 | 0.269275 | 0.943529 |
1 | 0.972289 | 0.123328 | 0.937571 | 0.123652 | 0.95493 |
```
>>>> Gallery 1 top1: 0.973086, top5: 0.978868, top10: 0.981659
>>>> Gallery 2 top1: 0.934336, top5: 0.960030, top10: 0.966121
>>>> Mean [Mean] top1: 0.953262, top5: 0.969231, top10: 0.973710
```
far | g1_tpir | g1_thresh | g2_tpir | g2_thresh | mean_tpir |
---|---|---|---|---|---|
0.0001 | 0.225279 | 0.839388 | 0.0371146 | 0.921194 | 0.131197 |
0.001 | 0.399522 | 0.78508 | 0.186334 | 0.838523 | 0.292928 |
0.01 | 0.916467 | 0.408205 | 0.863533 | 0.408381 | 0.89 |
0.1 | 0.961324 | 0.263613 | 0.923297 | 0.264048 | 0.94231 |
1 | 0.973086 | 0.11342 | 0.934336 | 0.114038 | 0.953711 |
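The `top1` / `top5` / `top10` lines above are closed-set identification rates: for each probe, rank the gallery by similarity and check whether the true identity is among the first k. A simplified stdlib sketch, assuming one gallery entry per identity; `top_k_accuracy` is hypothetical and not the code in `IJB_evals.py`. The `tpir` columns additionally apply a score threshold at each `far` level (the `g1_thresh` / `g2_thresh` columns), which this sketch leaves out.

```python
from typing import Dict, List, Sequence


def top_k_accuracy(similarity: List[List[float]], probe_ids: List[int],
                   gallery_ids: List[int], ks: Sequence[int] = (1, 5, 10)) -> Dict[int, float]:
    """Fraction of probes whose true identity is among the k most similar gallery entries."""
    hits = {k: 0 for k in ks}
    for row, true_id in zip(similarity, probe_ids):
        # Gallery indices sorted by decreasing similarity to this probe.
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranked_ids = [gallery_ids[j] for j in ranked]
        for k in ks:
            if true_id in ranked_ids[:k]:
                hits[k] += 1
    return {k: hits[k] / len(probe_ids) for k in ks}


sim = [[0.9, 0.1, 0.2], [0.3, 0.2, 0.8], [0.5, 0.6, 0.1]]
print(top_k_accuracy(sim, probe_ids=[10, 11, 10], gallery_ids=[10, 11, 12], ks=(1, 2)))
```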
Do you plan to implement their (insightface's) latest work? It seems it does not work as well as they claimed in their papers. https://github.com/deepinsight/insightface/issues/1801