xiangking / ark-nlp

A private NLP coding package for quickly implementing SOTA solutions.
Apache License 2.0

Adding PGD adversarial training raises an error #46

Closed yysirs closed 2 years ago

yysirs commented 2 years ago
def _on_backward(
        self,
        inputs,
        outputs,
        logits,
        loss,
        gradient_accumulation_steps=1,
        **kwargs
    ):

        # average the loss when more than one GPU is used
        if self.n_gpu > 1:
            loss = loss.mean()
        # with gradient accumulation, divide by the number of accumulation steps
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        loss.backward()
        self.pgd.backup_grad()
        # adversarial training
        for t in range(self.K):
            self.pgd.attack(is_first_attack=(t==0)) # add an adversarial perturbation to the embeddings; back up param.data on the first attack
            if t != self.K-1:
                self.module.zero_grad()
            else:
                self.pgd.restore_grad()
            logits = self.module(**inputs)
            logits, loss_adv = self._get_train_loss(inputs, outputs, **kwargs)
            # average the adversarial loss when more than one GPU is used
            if self.n_gpu > 1:
                loss_adv = loss_adv.mean()
            # with gradient accumulation, divide by the number of accumulation steps
            if gradient_accumulation_steps > 1:
                loss_adv = loss_adv / gradient_accumulation_steps
            loss_adv.backward()
        self.pgd.restore() # restore the clean embedding parameters

        self._on_backward_record(loss, **kwargs)

        return loss

Adding the PGD adversarial training steps after loss.backward() raises the following error:

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Could the author please explain what is going on here?
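
For reference, self.pgd in the snippet above is assumed to be the usual PGD attack helper that perturbs the embedding weights and backs up/restores parameters and gradients between attack steps. A minimal sketch of such a helper, following the widely circulated PyTorch implementation (the names emb_name, epsilon and alpha below are illustrative, not necessarily ark-nlp's own):

import torch

class PGD(object):
    """Projected Gradient Descent adversarial training on the embedding layer."""

    def __init__(self, module, emb_name='word_embeddings', epsilon=1.0, alpha=0.3):
        self.module = module
        self.emb_name = emb_name      # substring identifying the embedding parameters
        self.epsilon = epsilon        # radius of the L2 ball the perturbation is projected into
        self.alpha = alpha            # step size of each attack
        self.emb_backup = {}
        self.grad_backup = {}

    def attack(self, is_first_attack=False):
        for name, param in self.module.named_parameters():
            if param.requires_grad and self.emb_name in name:
                if is_first_attack:
                    # back up the clean embedding weights before the first perturbation
                    self.emb_backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    r_at = self.alpha * param.grad / norm
                    param.data.add_(r_at)
                    param.data = self.project(name, param.data, self.epsilon)

    def restore(self):
        # restore the clean embedding weights after the last attack step
        for name, param in self.module.named_parameters():
            if param.requires_grad and self.emb_name in name:
                assert name in self.emb_backup
                param.data = self.emb_backup[name]
        self.emb_backup = {}

    def project(self, param_name, param_data, epsilon):
        # project the accumulated perturbation back into the epsilon-ball
        r = param_data - self.emb_backup[param_name]
        if torch.norm(r) > epsilon:
            r = epsilon * r / torch.norm(r)
        return self.emb_backup[param_name] + r

    def backup_grad(self):
        # back up the gradients of the clean forward/backward pass
        for name, param in self.module.named_parameters():
            if param.requires_grad and param.grad is not None:
                self.grad_backup[name] = param.grad.clone()

    def restore_grad(self):
        # restore the clean gradients before the final adversarial backward,
        # so clean and adversarial gradients are accumulated together
        for name, param in self.module.named_parameters():
            if param.requires_grad and name in self.grad_backup:
                param.grad = self.grad_backup[name]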

browserliu commented 2 years ago

How was this problem solved? I ran into the same error. After some searching I found the loss graph had already been freed, so I used loss.backward(retain_graph=True), but the results were worse than with FGM.
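
The error usually means the adversarial loss is still tied to the first forward pass, whose graph is freed by the initial loss.backward(). In the snippet above, _get_train_loss is called with the original outputs argument instead of the freshly computed adversarial output, so the second backward walks the freed graph. A minimal sketch of the method with only that changed, assuming _get_train_loss computes the loss from whatever model output is passed as its second argument (an assumption based on the snippet above, not a confirmed ark-nlp API); with a fresh forward pass per attack step, retain_graph=True should not be needed:

def _on_backward(
    self,
    inputs,
    outputs,
    logits,
    loss,
    gradient_accumulation_steps=1,
    **kwargs
):
    if self.n_gpu > 1:
        loss = loss.mean()
    if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps

    loss.backward()
    self.pgd.backup_grad()

    for t in range(self.K):
        self.pgd.attack(is_first_attack=(t == 0))
        if t != self.K - 1:
            self.module.zero_grad()
        else:
            self.pgd.restore_grad()

        # fresh forward pass on the perturbed embeddings builds a new, intact graph
        adv_outputs = self.module(**inputs)
        # pass the fresh output, not the stale `outputs` argument, to the loss
        # (assumption: _get_train_loss derives the loss from its second argument)
        _, loss_adv = self._get_train_loss(inputs, adv_outputs, **kwargs)

        if self.n_gpu > 1:
            loss_adv = loss_adv.mean()
        if gradient_accumulation_steps > 1:
            loss_adv = loss_adv / gradient_accumulation_steps

        # backward only through the new graph, so retain_graph=True is not needed
        loss_adv.backward()

    self.pgd.restore()

    self._on_backward_record(loss, **kwargs)

    return loss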