t6am3 / public_tianchi_yiqing_nlp

天池疫情公益文本相似对比大赛
20 stars 11 forks source link

关于对抗训练 #3

Open senchfu opened 4 years ago

senchfu commented 4 years ago

您好,关于对抗训练fgm的代码

if args.adv_fgm:
                fgm.attack() # 在embedding上添加对抗扰动
                loss_adv = model(**inputs)[0]
                loss_adv.backward() # 反向传播,并在正常的grad基础上,累加对抗训练的梯度
                fgm.restore() # 恢复embedding参数

如果使用多卡的话,例如n_gpu=2,这里的loss_adv是不是要取一下平均,也就是loss_adv=loss_adv.mean()。期待您的回复,谢谢!