Closed weWillGetThere closed 1 year ago
hello 你好,抱歉刚看到。三个问题的回复如下:
所以你可以理解 mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k) 这个code的意思是mask_k==1 and softmax_pred 值大时,把这样的 prob mask掉, 但其实这种情况出现的概率不高,加上了训练会更稳定。
如果还有疑问的话,欢迎邮件咨询 yhao.chen0617@gmail.com
代码勘误:根据论文3.2和3.3节的描述(公式(3)、公式(9)),弱数据增强的预测概率q^i_c(即pseudo_label)中不低于p_cutoff的部分s^i_c为1,但代码中用的是强数据增强的预测概率(即softmax_pred)。以下是我从flexmatch_util.py中摘取的例子:
mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k)
另外,我在拜读完您团队的论文(CVPR 2023 PDF)以及代码后,对于弱数据增强预测概率有多个不低p_cutoff的值的情况产生了疑惑。
在求解ANL(即loss_npl)时,论文中说掩膜是topk之后的部分,而代码(即flexmatch_util.py,后同)中还剔除了topk后不低于p_cutoff的部分,即提到了:
mask_k_npl = F.where((mask_k==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask_k), mask_k)
在求解EML(即loss_em)时,论文中的公式(5)说分母是所有u^i_c=1的个数,即不包含预测概率非top1但不低于p_cutoff的部分;但看前面的论述部分和3.4节的"This means the count of non-target class in examples with pseudo-label is k − 1 instead of C − 1",以及您的代码:
mask_k = F.scatter(F.ones_like(pseudo_label), 1, topk, F.zeros_like(topk))
mask_k = F.scatter(mask_k, 1, label.reshape(-1,1), F.ones_like(label.reshape(-1,1)))
yg = F.cond_take(mask_k.astype('bool'), softmax_pred)[0].reshape(pred_w.shape[0],-1).sum(axis=-1,keepdims=True)
soft_ml = F.broadcast_to((1-yg+1e-7)/(k-1), pred_s.shape)
,您应该想表达的意思是分母包含该部分,即统一为C-1或k-1。不知道我的理解是否正确? 而公式(6)使用了u^i_c,即不包含预测概率非top1但高于p_cutoff的部分,您的代码中也进行了剔除:mask = F.where((mask==1)&(softmax_pred>p_cutoff**2), F.zeros_like(mask), mask)
我想询问一下,为什么两个公式对于该部分的处理不统一?既然公式(6)要剔除该部分,为什么公式(5)中不也将其剔除?又或者公式(5)在包含了该部分的情况下平均概率,公式(6)为什么不将其包含?期待您的回复。