megvii-research / FullMatch

Official implementation of FullMatch (CVPR2023)
Apache License 2.0
44 stars 2 forks source link

勘误以及关于弱数据增强预测概率有多个不低于p_cutoff的值的情况的疑问 #1

Closed weWillGetThere closed 1 year ago

weWillGetThere commented 1 year ago

代码勘误:根据论文3.2和3.3节的描述(公式(3)、公式(9)),弱数据增强的预测概率q^i_c(即pseudo_label)中不低于p_cutoff的部分s^i_c为1,但代码中用的是强数据增强的预测概率(即softmax_pred)。以下是我从flexmatch_util.py中摘取的例子: mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k)

另外,我在拜读完您团队的论文(CVPR 2023 PDF)以及代码后,对于弱数据增强预测概率有多个不低p_cutoff的值的情况产生了疑惑。

  1. 在求解ANL(即loss_npl)时,论文中说掩膜是topk之后的部分,而代码(即flexmatch_util.py,后同)中还剔除了topk后不低于p_cutoff的部分,即提到了: mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k)

  2. 在求解EML(即loss_em)时,论文中的公式(5)说分母是所有u^i_c=1的个数,即不包含预测概率非top1但不低于p_cutoff的部分;但看前面的论述部分和3.4节的"This means the count of non-target class in examples with pseudo-label is k − 1 instead of C − 1",以及您的代码: mask_k = F.scatter(F.ones_like(pseudo_label), 1, topk, F.zeros_like(topk)) mask_k = F.scatter(mask_k, 1, label.reshape(-1,1), F.ones_like(label.reshape(-1,1))) yg = F.cond_take(mask_k.astype('bool'), softmax_pred)[0].reshape(pred_w.shape[0],-1).sum(axis=-1,keepdims=True) soft_ml = F.broadcast_to((1-yg+1e-7)/(k-1), pred_s.shape) ,您应该想表达的意思是分母包含该部分,即统一为C-1或k-1。不知道我的理解是否正确? 而公式(6)使用了u^i_c,即不包含预测概率非top1但高于p_cutoff的部分,您的代码中也进行了剔除: mask = F.where((mask==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask), mask) 我想询问一下,为什么两个公式对于该部分的处理不统一?既然公式(6)要剔除该部分,为什么公式(5)中不也将其剔除?又或者公式(5)在包含了该部分的情况下平均概率,公式(6)为什么不将其包含?

期待您的回复。

yhaochen0617 commented 1 year ago

hello 你好,抱歉刚看到。三个问题的回复如下:

  1. ”弱数据增强的预测概率q^i_c(即pseudo_label)中不低于p_cutoff的部分s^i_c为1,但代码中用的是强数据增强的预测概率(即softmax_pred)“,你说的这部分是fixmatch的过程,不在这个函数里面,见consistency_loss.
  2. ”而代码(即flexmatch_util.py,后同)中还剔除了topk后不低于p_cutoff的部分“,这个是为了训练稳定性用的,在多个setting下跑多个seed时,我们发现偶尔会选到 strong-aug中得分特别高的从而训练不太稳定,所以加了一个工程过滤。但其实出现这种情况的概率较低;
  3. ”论文中的公式(5)说分母是所有u^i_c=1的个数,即不包含预测概率非top1但不低于p_cutoff的部分;“,这个你理解错了,U_i==1的个数就是k-1(when using ANL),同样工程过滤也应用在这部分了。

所以你可以理解 mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k) 这个code的意思是mask_k==1 and softmax_pred 值大时,把这样的 prob mask掉, 但其实这种情况出现的概率不高,加上了训练会更稳定。

如果还有疑问的话,欢迎邮件咨询 yhao.chen0617@gmail.com