geyingli / unif

基于 Tensorflow,仿 Scikit-Learn 设计的深度学习自然语言处理框架。支持 40 余种模型类,涵盖语言模型、文本分类、NER、MRC、知识蒸馏等各个领域
Apache License 2.0
114 stars 27 forks source link

对抗训练tf2 #6

Open luoda888 opened 3 years ago

luoda888 commented 3 years ago
class FreeAT(tf.keras.Model):
    def train_step(self, data):
        x, y = data
        last_r = 0.0
        last_r_slice = 0.0
        K = 3
        ep = 1e-3

        for t in range(K):
            with tf.GradientTape() as tape:
                y_pred = self(x, training=True)
                loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)        
            embedding_gradients = tape.gradient(loss, [self.trainable_variables[0]])[0]
            grad_values = tf.zeros_like(self.trainable_variables[0]) + embedding_gradients
            sign = tf.cast(tf.greater(grad_values, 0.0), tf.float32)
            r = last_r + tf.multiply(ep, sign) if t > 0 else \
                    tf.multiply(ep, sign)
            r *= tf.divide(ep, tf.norm(r))
            r_slice = tf.IndexedSlices(
                values=r,
                indices=embedding_gradients.indices,
                dense_shape=embedding_gradients.dense_shape)
            self.trainable_variables[0].assign_add(r_slice - last_r_slice)
            last_r = r
            last_r_slice = r_slice

        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

您好,这是我用tf2的特性重写的一个FreeAT,但是在实验效果上差了不少,是不是因为没有restore_grad的缘故?另外想请教,SMART与FreeLB的自定义r变量应该在tf2中如何实现呢,感激不尽

luoda888 commented 3 years ago

PGD 与 FGM都能正常work,但是FreeLB和Yopo、Smart这些不知道该咋改成tf2版本。感觉TF2版本是可以即插即用的,在keras模型定义层,用ADV_MODEL替换Model就好

geyingli commented 3 years ago

谢谢你的提问。tf2出来以后我本人一直是持观望态度的,依然在使用tf.compat.v1的api,使用静态图。所以具体地,代码上如何实现,的确是没办法给出什么建议的。但既然你会写PGD和FGM,在这基础上修改,写出FreeLB和SMART我想应该是不难的,不妨花点时间读一读论文的伪代码或UNIF的实现,尝试一下~

geyingli commented 3 years ago

回一下第一个问题,restore_grad肯定是要有的,r的实现其实是自定义一个随机变量作为r,不难。另外,无论是我这里的实现,还是别人用pytorch的实现,FreeAT是公认效果比较差的,时间急切的话可以先放弃哈

luoda888 commented 3 years ago

感谢,除了FreeLB和SMART、Yopo,您还有什么推荐的对抗学习的算法么

geyingli commented 3 years ago

对抗式学习的主流算法就是这几个了,在GLUE榜单上你可以看到。但是我有一段时间没有看论文了,不排除有更好的对抗式学习算法诞生的可能。按照表现设置优先级,SMART > FreeLB > Yopo > PGD > FGM