zhuchen03 / FreeLB

Adversarial Training for Natural Language Understanding

Has anyone encountered NaN errors in the final epochs of training? #8

Closed. PantherYan closed this issue 4 years ago

PantherYan commented 4 years ago

First, thanks for your wonderful work.

Has anyone run into NaN errors during the final epochs of training?

I embedded FreeLB into my network as a plugin, through freelb.attack() and freelb.update() (without handling dropout_mask). However, I run into NaN errors. The backbone is BERT.

In the early epochs everything works well: the loss converges and accuracy improves. But once the loss has converged to a very small value, Apex (fp16) shrinks the loss scale to something tiny, about 1e-100, and then the NaN error appears.

zhuchen03 commented 4 years ago

Does this still happen when you use fp32? It could be caused by improper handling of aggregation operations in fp16. For example, when computing the norm, you should first convert the variable to fp32, then compute the norm, and then switch back to fp16. Could you post the code?
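A small numpy sketch makes the range argument concrete (numpy's float16 uses the same IEEE half-precision format as torch.half). It only shows what happens when intermediate results are kept in fp16; whether a particular PyTorch kernel accumulates in higher precision internally depends on the op and the version. The array sizes and values are illustrative assumptions.

```python
import numpy as np

# fp16's largest finite value is 65504; its smallest positive (subnormal) value is about 6e-8.
fp16_max = np.finfo(np.float16).max

# Overflow: 768 entries of 10.0 have a sum of squares of 76800, which exceeds 65504.
big = np.full(768, 10.0, dtype=np.float16)
sq_sum_fp16 = np.sum(big * big, dtype=np.float16)           # inf
norm_fp32 = np.sqrt(np.sum(big.astype(np.float32) ** 2))    # about 277.13

# Underflow: squaring ~1e-4 gives ~1e-8, which rounds to 0 in fp16.
small = np.full(768, 1e-4, dtype=np.float16)
sq_sum_small_fp16 = np.sum(small * small, dtype=np.float16)        # 0.0
norm_small_fp32 = np.sqrt(np.sum(small.astype(np.float32) ** 2))   # about 2.77e-3

# The same rounding means a clamp floor of 1e-8 becomes 0 once stored in fp16.
eps_fp16 = np.float16(1e-8)
```

The fp32 path keeps both the large and the tiny case finite and nonzero, which is the point of upcasting before a reduction.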

PantherYan commented 4 years ago

Thanks for your prompt reply. I used the settings you suggested for GLUE. I suspect I may have set the learning rate too high (1e-2). I will check these settings after the current run finishes. If the error persists, I will post the code and the error output later.

PantherYan commented 4 years ago

@zhuchen03

Here is the error:

    next(self.gen)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
    optimizer._post_amp_backward(loss_scaler)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
    post_backward_models_are_masters(scaler, params, stashed_grads)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
    scale_override=(grads_have_scale, stashed_have_scale, out_scale))
  File "/usr/local/lib/python3.6/site-packages/apex/amp/scaler.py", line 176, in unscale_with_stashed
    out_scale/grads_have_scale,  # 1./scale,
ZeroDivisionError: float division by zero
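The last frame divides by the current loss scale, so the scale itself has reached 0, consistent with the tiny scales reported above. Dynamic loss scaling halves the scale each time it detects inf/NaN gradients; if every step overflows, repeatedly halving a Python float eventually underflows to exactly 0.0. The toy loop below is a hypothetical model of that behavior, not Apex's code, and the starting scale of 2**16 is an assumption.

```python
def halve_until_zero(init_scale=2.0 ** 16, max_steps=5000, min_scale=None):
    """Toy dynamic loss scaler (hypothetical, not Apex's implementation).

    Assumes every step produces inf/NaN gradients, so the scale is halved
    on each step. Returns (number_of_overflowing_steps, final_scale).
    """
    scale = init_scale
    for step in range(1, max_steps + 1):
        scale = scale / 2.0                  # back off after an overflow
        if min_scale is not None:
            scale = max(scale, min_scale)    # optional floor on the scale
        if scale == 0.0:
            return step, scale               # the float has underflowed to zero
    return max_steps, scale


steps, scale = halve_until_zero()
try:
    1.0 / scale                              # what unscaling effectively computes
except ZeroDivisionError as exc:
    print(f"after {steps} overflowing steps: {exc}")
```

A floor (e.g. `min_scale=1.0`) avoids the ZeroDivisionError, but the inf/NaN gradients that keep triggering the back-off remain, so it only hides the symptom. Apex's `amp.initialize` also accepts a `min_loss_scale` argument for this purpose; check the signature in your installed version.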

Here is the modified code:

import torch

class FreeLB():
    def __init__(self, args, model):
        self.model = model
        self.dropMask = None
        self.delta = None
        self.args = args
        self.dp_masks = None
        self.embeds_init = None

    def attack(self, inputs, is_first_attack=False):
        if is_first_attack:
            if isinstance(self.model, torch.nn.DataParallel):
                self.embeds_init = self.model.module.bert.embeddings.word_embeddings(inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
            else:
                self.embeds_init = self.model.bert.embeddings.word_embeddings(inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))

            if self.args.adv_init_mag > 0:
                input_mask = inputs['attention_mask'].view(-1, inputs['attention_mask'].shape[-1]).to(self.embeds_init)
                input_lengths = torch.sum(input_mask, 1)
                # check the shape of the mask here..
                if self.args.norm_type == "l2":
                    delta = torch.zeros_like(self.embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                    dims = input_lengths * self.embeds_init.size(-1)
                    mag = self.args.adv_init_mag / torch.sqrt(dims)
                    delta = (delta * mag.view(-1, 1, 1)).detach()
                elif self.args.norm_type == "linf":
                    delta = torch.zeros_like(self.embeds_init).uniform_(-self.args.adv_init_mag, self.args.adv_init_mag) * input_mask.unsqueeze(2)
            else:
                delta = torch.zeros_like(self.embeds_init)

            # (number*choice, 512, 768) -> (number, choice, 512, 768)
            self.delta = delta.view(inputs['input_ids'].shape[0], inputs['input_ids'].shape[1], inputs['input_ids'].shape[2], delta.shape[-1])
            self.delta.requires_grad_()
        else:  # not the first attack
            self.delta.requires_grad_()
            if isinstance(self.model, torch.nn.DataParallel):
                self.embeds_init = self.model.module.bert.embeddings.word_embeddings(inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
            else:
                self.embeds_init = self.model.bert.embeddings.word_embeddings(inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))

        if isinstance(self.model, torch.nn.DataParallel):
            inputs['inputs_embeds'] = self.delta + self.model.module.bert.embeddings.word_embeddings(inputs['input_ids'])
        else:
            inputs['inputs_embeds'] = self.delta + self.model.bert.embeddings.word_embeddings(inputs['input_ids'])
        # inputs['dp_masks'] = self.dp_masks
        # outputs, dp_masks = models(**inputs)

        outputs = self.model(**inputs)
        return outputs

    def update(self):
        delta_grad = self.delta.grad.clone().detach()
        self.delta = self.delta.view(delta_grad.size(0)*delta_grad.size(1), delta_grad.size(2), -1)  # 6*512*768

        if self.args.norm_type == "l2":
            denorm = torch.norm(delta_grad.view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1)
            denorm = torch.clamp(denorm, min=1e-8)
            self.delta = (self.delta + self.args.adv_lr * delta_grad.view(delta_grad.size(0)*delta_grad.size(1), delta_grad.size(2), -1) / denorm).detach()
            if self.args.adv_max_norm > 0:
                delta_norm = torch.norm(self.delta.view(self.delta.size(0), -1).float(), p=2, dim=1).detach()
                exceed_mask = (delta_norm > self.args.adv_max_norm).to(self.embeds_init)
                reweights = (self.args.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                self.delta = (self.delta * reweights).detach()
        elif self.args.norm_type == "linf":
            denorm = torch.norm(delta_grad.view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1, p=float("inf")).view(-1, 1, 1)
            denorm = torch.clamp(denorm, min=1e-8)
            self.delta = (self.delta + self.args.adv_lr * delta_grad.view(delta_grad.size(0)*delta_grad.size(1), delta_grad.size(2), -1) / denorm).detach()
            if self.args.adv_max_norm > 0:
                self.delta = torch.clamp(self.delta, -self.args.adv_max_norm, self.args.adv_max_norm).detach()
        else:
            raise ValueError("Norm type {} not specified.".format(self.args.norm_type))
        self.delta = self.delta.view(delta_grad.size(0), delta_grad.size(1), delta_grad.size(2), -1)  # 6*512*768 -> 2*3*512*768
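For reference, the L2 branch of the initialization in attack() draws uniform noise on real (unmasked) tokens and divides by sqrt(length * hidden_size), so each example's initial perturbation has an L2 norm of at most adv_init_mag. Below is a numpy sketch of that arithmetic with the token count kept in fp32; the function name and shapes are hypothetical. As a side note, if input_mask ends up in fp16 (i.e. if embeds_init is fp16), length * 768 exceeds fp16's maximum of 65504 once a sequence has more than 85 tokens, which would make mag zero.

```python
import numpy as np

def l2_random_init(input_mask, hidden_size, adv_init_mag, rng):
    """Hypothetical numpy version of the L2 init in attack().

    input_mask: (batch, seq_len) array of 0/1 (1 = real token).
    Returns delta of shape (batch, seq_len, hidden_size).
    """
    mask = input_mask.astype(np.float32)
    delta = rng.uniform(-1.0, 1.0, size=mask.shape + (hidden_size,)).astype(np.float32)
    delta *= mask[:, :, None]                  # no perturbation on padding positions
    dims = mask.sum(axis=1) * hidden_size      # perturbed entries per example, in fp32
    mag = adv_init_mag / np.sqrt(dims)         # every |entry| < mag, so norm <= adv_init_mag
    return delta * mag[:, None, None]


rng = np.random.default_rng(0)
mask = np.zeros((6, 128), dtype=np.int64)
mask[:, :100] = 1                              # 100 real tokens per example
delta = l2_random_init(mask, 768, adv_init_mag=1.0, rng=rng)
norms = np.linalg.norm(delta.reshape(6, -1), axis=1)
```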

Here is the training function:

if args.adv_step > 0 and args.do_attack != 'freelb':
    args.adv_step = 1

for k in range(args.adv_step):
    if args.do_attack == 'freelb':
        outputs = flb.attack(inputs, is_first_attack=(k == 0))
    else:
        outputs = model(**inputs)
    loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

    # shared loss update
    if args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training
    septr_loss[task_id] += loss.item()
    septr_num[task_id] += 1
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps

    if args.fp16:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

    if args.do_attack == 'freelb':
        flb.update()

zhuchen03 commented 4 years ago

Can you first check whether it gives reasonable results under fp32? If so, the problem lies in fp16 training. As I mentioned before, you have to convert the variables to fp32 before feeding them into any aggregation operation and then cast the result back to fp16, since fp16 has a very limited range. In this example, you need to change at least

denorm = torch.norm(delta_grad.view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1)

into something like

denorm = torch.norm(delta_grad.float().view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1).to(delta_grad)

I didn't handle this part properly in the Hugging Face implementation, but you should pay attention to it.

Also, this normalization is kind of different from my original approach as defined here.

PantherYan commented 4 years ago

It goes fine with fp32.

I revised the normalization because of the data dimensions in my task.
My input has 3 dimensions, number*choice*512, unlike GLUE's number*512. So for the embedding and denorm, I reshaped the tensors so that the norm is computed consistently with yours.

The shape number*choice*512*embedding_size, i.e. 2*3*512*768, is reshaped to 6*512*768.
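The reshaping described above can be sketched in numpy, with reshape standing in for torch's view; the random gradient is illustrative and the sizes 2, 3, 512, 768 are taken from the comment. Merging the first two axes yields one L2 norm per (example, choice) pair, which plays the role of denorm, and the normalized step is reshaped back to 4-D afterwards.

```python
import numpy as np

# Shapes from the comment: (number, choice, seq_len, hidden).
number, choice, seq_len, hidden = 2, 3, 512, 768
delta_grad = np.random.default_rng(0).standard_normal(
    (number, choice, seq_len, hidden)).astype(np.float32)

# Merge the first two axes: (2, 3, 512, 768) -> (6, 512, 768).
flat = delta_grad.reshape(number * choice, seq_len, hidden)

# One L2 norm per (example, choice) pair, shaped (6, 1, 1) for broadcasting.
denorm = np.linalg.norm(flat.reshape(number * choice, -1), axis=1).reshape(-1, 1, 1)
step = flat / np.maximum(denorm, 1e-8)

# Restore the original 4-D layout.
step = step.reshape(number, choice, seq_len, hidden)
```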

Here,

denorm = torch.norm(delta_grad.float().view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1).to(delta_grad)

you added float() and to(delta_grad), unlike my original line.

  1. Is float() there to keep the computation in fp32?
  2. Is to(delta_grad) there to keep the same CUDA device and dtype?
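On these two questions: in PyTorch, Tensor.float() is equivalent to casting to torch.float32, and Tensor.to(other) returns a tensor with the same dtype and device as other. So the suggested line computes the norm in fp32 and then casts it back to delta_grad's dtype and device. The numpy sketch below mirrors that upcast, reduce, cast-back pattern (numpy has no devices, so only the dtype part carries over); the function names and test values are hypothetical.

```python
import numpy as np

def row_norm_fp32(grad):
    """Per-row L2 norm computed in fp32 and returned in grad's dtype.

    Numpy analog of
    torch.norm(grad.float().view(n, -1), dim=1).view(-1, 1, 1).to(grad)
    """
    n = grad.shape[0]
    norm32 = np.linalg.norm(grad.astype(np.float32).reshape(n, -1), axis=1)
    return norm32.reshape(-1, 1, 1).astype(grad.dtype)   # cast back, like .to(grad)


def row_norm_fp16_only(grad):
    """Same norm with every intermediate kept in fp16, for comparison."""
    n = grad.shape[0]
    flat = grad.reshape(n, -1)
    return np.sqrt(np.sum(flat * flat, axis=1, dtype=np.float16)).reshape(-1, 1, 1)


large = np.full((6, 16, 768), 10.0, dtype=np.float16)   # sum of squares 1,228,800 per row
tiny = np.full((6, 16, 768), 1e-4, dtype=np.float16)    # each square underflows in fp16
```

The norm itself (about 1108.5 for `large`) fits in fp16 even though the sum of squares does not, which is why casting back after the reduction is safe here.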

I have another question about the dropout mask. Your solution keeps the same dropout mask in all the blocks.

Could it be replaced by a fixed-mask dropout module created in the __init__() of each BERT block (such as the encoder)? Something like this:

import torch.nn as nn

class LockedDropout(nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        """
        Args:
            x (:class:`torch.FloatTensor` [sequence length, batch size, rnn hidden size]): Input to
                apply dropout to.
        """
        if not self.training or not self.p:
            return x
        # One mask of shape (1, batch, hidden), shared along the sequence dimension.
        mask = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(1 - self.p)
        mask = mask.div_(1 - self.p)
        mask = mask.expand_as(x)
        return x * mask

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'p=' + str(self.p) + ')'
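For comparison, torch.nn.Dropout draws a fresh mask on every forward call, and so does LockedDropout; it only shares one mask along the first (sequence) dimension. Reusing the same mask across the adversarial steps of one batch, which is what the commented-out dp_masks lines in attack() pass in and out of the model, requires caching the mask between calls. Below is a numpy sketch of that idea under stated assumptions: FrozenMaskDropout and its reset() method are hypothetical names, and reset() would be called once per batch, before the first attack step.

```python
import numpy as np

class FrozenMaskDropout:
    """Hypothetical dropout whose mask is sampled once and then reused
    until reset() is called (e.g. across the K ascent steps of one batch)."""

    def __init__(self, p=0.1, seed=None):
        self.p = p
        self.training = True
        self.mask = None
        self.rng = np.random.default_rng(seed)

    def reset(self):
        # Forget the cached mask so the next call samples a fresh one.
        self.mask = None

    def __call__(self, x):
        if not self.training or self.p == 0:
            return x
        if self.mask is None or self.mask.shape != x.shape:
            keep = self.rng.random(x.shape) >= self.p      # keep with probability 1 - p
            # Inverted dropout: scale kept units by 1 / (1 - p).
            self.mask = keep.astype(x.dtype) / (1.0 - self.p)
        return x * self.mask


drop = FrozenMaskDropout(p=0.5, seed=0)
x = np.ones((2, 4, 8))
y_step1 = drop(x)      # first ascent step samples the mask
y_step2 = drop(x)      # later steps reuse it
```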