xiangking / ark-nlp

A private NLP coding package that quickly implements SOTA solutions.
Apache License 2.0

How can PGD adversarial training be added correctly? Reference: https://github.com/xiangking/ark-nlp/issues/46 #50

Closed browserliu closed 2 years ago

browserliu commented 2 years ago

Following https://github.com/xiangking/ark-nlp/issues/46, I found that the loss graph was being freed, so I switched to loss.backward(retain_graph=True), replacing every loss.backward() call with loss.backward(retain_graph=True). Training then runs normally, but the F1 score is lower than with FGM. How can PGD adversarial training be added correctly?

def _on_backward(
    self,
    inputs,
    outputs,
    logits,
    loss,
    gradient_accumulation_steps=1,
    **kwargs
):

    # average the loss across GPUs when more than one is used
    if self.n_gpu > 1:
        loss = loss.mean()
    # scale the loss when gradient accumulation is used
    if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps

    loss.backward()
    self.pgd.backup_grad()
    # adversarial training
    for t in range(self.K):
        self.pgd.attack(is_first_attack=(t == 0))  # perturb the embeddings; back up param.data on the first attack
        if t != self.K - 1:
            self.module.zero_grad()
        else:
            self.pgd.restore_grad()
        logits = self.module(**inputs)
        logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
        # average the loss across GPUs when more than one is used
        if self.n_gpu > 1:
            loss_adv = loss_adv.mean()
        # scale the loss when gradient accumulation is used
        if gradient_accumulation_steps > 1:
            loss_adv = loss_adv / gradient_accumulation_steps
        loss_adv.backward()
    self.pgd.restore()  # restore the embedding parameters

    self._on_backward_record(loss, **kwargs)

    return loss

Adding the PGD adversarial training after loss.backward() raises:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
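
Judging from the fix below, the likely root cause is that inside the attack loop loss_adv is computed from the stale outputs of the first forward pass rather than the fresh logits, so the second backward() traverses a graph the first backward() already freed; retain_graph=True then merely re-backpropagates the original, unperturbed loss, which would also explain why training "works" yet scores below FGM.

The PGD helper called throughout this thread (attack/restore, backup_grad/restore_grad) is never shown in the issue. Below is a minimal sketch of the widely used implementation it appears to follow; emb_name, epsilon and alpha are illustrative defaults, not values taken from this issue:

import torch

class PGD(object):
    # Minimal sketch of the PGD helper assumed by this thread.
    def __init__(self, module, emb_name='word_embeddings', epsilon=1.0, alpha=0.3):
        self.module = module
        self.emb_name = emb_name  # substring identifying the embedding weights
        self.epsilon = epsilon    # radius of the L2 ball the perturbation is projected onto
        self.alpha = alpha        # step size of each attack
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, is_first_attack=False):
        # perturb the embedding weights along the gradient direction
        for name, param in self.module.named_parameters():
            if param.requires_grad and self.emb_name in name:
                if is_first_attack:
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data)

    def project(self, param_name, param_data):
        # keep the accumulated perturbation inside the epsilon-ball
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > self.epsilon:
            r = self.epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def restore(self):
        # undo the perturbation after the attack steps
        for name, param in self.module.named_parameters():
            if param.requires_grad and self.emb_name in name:
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def backup_grad(self):
        # snapshot the clean gradients before the attack loop
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:  # None check, see the fix below
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        # restore the clean gradients so the last attack step accumulates onto them
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                param.grad = self.grad_backup[name]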

xiangking commented 2 years ago

The error reported in issue #46 is not expected behavior. After reproducing the problem, I recommend using PGD as follows:

import warnings

from torch.utils.data import DataLoader
from ark_nlp.factory.optimizer import get_optimizer
from ark_nlp.model.ner.global_pointer_bert import Task

class AttackTask(Task):

    def _on_train_begin(
        self,
        train_data,
        validation_data,
        batch_size,
        lr,
        params,
        shuffle,
        num_workers=0,
        train_to_device_cols=None,
        **kwargs
    ):
        if hasattr(train_data, 'id2cat'):
            self.id2cat = train_data.id2cat
            self.cat2id = {v_: k_ for k_, v_ in train_data.id2cat.items()}

        # class_num may be passed at init; if it was not, read it from the training set here
        if self.class_num is None:
            if hasattr(train_data, 'class_num'):
                self.class_num = train_data.class_num
            else:
                warnings.warn("The class_num is None.")

        if train_to_device_cols is None:
            self.train_to_device_cols = train_data.to_device_cols
        else:
            self.train_to_device_cols = train_to_device_cols

        train_generator = DataLoader(
            train_data,
            batch_size=batch_size,
            shuffle=shuffle,  # pass the caller's shuffle flag through instead of hardcoding True
            num_workers=num_workers,
            collate_fn=self._train_collate_fn
        )
        self.train_generator_lenth = len(train_generator)

        self.optimizer = get_optimizer(self.optimizer, self.module, lr, params)
        self.optimizer.zero_grad()

        self.module.train()

        self.pgd = PGD(self.module)  # PGD helper, e.g. as sketched above
        self.K = 3                   # number of attack steps

        self._on_train_begin_record(**kwargs)

        return train_generator

    def _on_backward(
        self,
        inputs,
        outputs,
        logits,
        loss,
        gradient_accumulation_steps=1,
        **kwargs
    ):

        # average the loss across GPUs when more than one is used
        if self.n_gpu > 1:
            loss = loss.mean()
        # scale the loss when gradient accumulation is used
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()

        self.pgd.backup_grad()

        for t in range(self.K):
            self.pgd.attack(is_first_attack=(t == 0))  # perturb the embeddings; back up param.data on the first attack
            if t != self.K - 1:
                self.optimizer.zero_grad()
            else:
                self.pgd.restore_grad()

            logits = self.module(**inputs)
            _, attack_loss = self._get_train_loss(inputs, logits, **kwargs)

            attack_loss.backward()

        self.pgd.restore()  # restore the embedding parameters

        self._on_backward_record(loss, **kwargs)

        return loss
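
For completeness, a hypothetical usage sketch: apart from the Task import above, every name here (dl_module, optimizer, the datasets, the 'gpce' loss key, and the hyperparameters) is a placeholder following the usual ark-nlp GlobalPointer examples, not something confirmed in this issue:

# Hypothetical usage of the AttackTask defined above.
# `dl_module` is your GlobalPointerBert model, `optimizer` its optimizer, and
# `ner_train_dataset` / `ner_dev_dataset` your ark-nlp Dataset objects --
# all placeholders, set up as in the usual ark-nlp examples.
model = AttackTask(dl_module, optimizer, 'gpce', cuda_device=0)  # 'gpce' loss key assumed

model.fit(
    ner_train_dataset,
    ner_dev_dataset,
    lr=2e-5,          # illustrative hyperparameters
    epochs=5,
    batch_size=16
)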
paulpaulzhang commented 2 years ago

The PGD backup_grad and restore_grad functions are buggy: they raise an error when a parameter's gradient is None, so a None check is needed.

 # before
def backup_grad(self):
    for name, param in self.module.named_parameters():
        if param.requires_grad:
            self.grad_backup[name] = param.grad.clone()

def restore_grad(self):
    for name, param in self.module.named_parameters():
        if param.requires_grad:
            param.grad = self.grad_backup[name]

 # after
def backup_grad(self):
    for name, param in self.module.named_parameters():
        if param.requires_grad and param.grad is not None:
            self.grad_backup[name] = param.grad.clone()

def restore_grad(self):
    for name, param in self.module.named_parameters():
        if param.requires_grad and param.grad is not None:
            param.grad = self.grad_backup[name]
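
The None case is easy to hit: any parameter with requires_grad=True that never participates in the loss (an unused head, for example, or any parameter before its first backward) keeps param.grad = None, and param.grad.clone() then raises AttributeError. A self-contained illustration:

import torch
import torch.nn as nn

model = nn.ModuleDict({
    'used': nn.Linear(4, 1),
    'unused': nn.Linear(4, 1),  # requires_grad=True, but never enters the graph
})

loss = model['used'](torch.randn(2, 4)).sum()
loss.backward()

for name, param in model.named_parameters():
    print(name, param.grad is None)
# used.weight False / used.bias False
# unused.weight True / unused.bias True  <- param.grad.clone() would raise here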
xiangking commented 2 years ago

This case does exist; adding the None check is already on the update roadmap.