Does this still happen when you use fp32? It could be caused by an improper handling of aggregating operations in fp16, e.g., when computing the norm, you should first convert the variable into fp32, then compute the norm and switch back to fp16. Could you post the code?
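For example, the kind of handling I mean is roughly this (just a sketch of the pattern, not code from this repo):

```python
import torch

def norm_in_fp32(t, dim):
    # Cast to fp32 before the reduction, then cast the result back to t's original dtype.
    return torch.norm(t.float(), dim=dim).to(t)

grad = torch.randn(6, 512 * 768, dtype=torch.float16)   # stand-in for an fp16 gradient
denorm = norm_in_fp32(grad, dim=1).view(-1, 1, 1)        # shape (6, 1, 1), back in fp16
```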
Thanks for your prompt reply. I used the settings you suggested for GLUE. I thought it might be because I set the learning rate too large, at 1e-2. I will check these settings after the current run. If the error persists, I will post the code and the error message later.
@zhuchen03
Here is the error:
    next(self.gen)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
    optimizer._post_amp_backward(loss_scaler)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
    post_backward_models_are_masters(scaler, params, stashed_grads)
  File "/usr/local/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
    scale_override=(grads_have_scale, stashed_have_scale, out_scale))
  File "/usr/local/lib/python3.6/site-packages/apex/amp/scaler.py", line 176, in unscale_with_stashed
    out_scale/grads_have_scale,   # 1./scale,
ZeroDivisionError: float division by zero
Here is the modified code:
import torch


class FreeLB():
    def __init__(self, args, model):
        self.model = model
        self.dropMask = None
        self.delta = None
        self.args = args
        self.dp_masks = None
        self.embeds_init = None

    def attack(self, inputs, is_first_attack=False):
        # On the first adversarial step, initialize delta from the word embeddings; afterwards reuse it.
        if is_first_attack:
            if isinstance(self.model, torch.nn.DataParallel):
                self.embeds_init = self.model.module.bert.embeddings.word_embeddings(
                    inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
            else:
                self.embeds_init = self.model.bert.embeddings.word_embeddings(
                    inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
            if self.args.adv_init_mag > 0:
                input_mask = inputs['attention_mask'].view(-1, inputs['attention_mask'].shape[-1]).to(self.embeds_init)
                input_lengths = torch.sum(input_mask, 1)
                # check the shape of the mask here..
                if self.args.norm_type == "l2":
                    delta = torch.zeros_like(self.embeds_init).uniform_(-1, 1) * input_mask.unsqueeze(2)
                    dims = input_lengths * self.embeds_init.size(-1)
                    mag = self.args.adv_init_mag / torch.sqrt(dims)
                    delta = (delta * mag.view(-1, 1, 1)).detach()
                elif self.args.norm_type == "linf":
                    delta = torch.zeros_like(self.embeds_init).uniform_(
                        -self.args.adv_init_mag, self.args.adv_init_mag) * input_mask.unsqueeze(2)
            else:
                delta = torch.zeros_like(self.embeds_init)
            self.delta = delta.view(inputs['input_ids'].shape[0], inputs['input_ids'].shape[1],
                                    inputs['input_ids'].shape[2], delta.shape[-1])
            self.delta.requires_grad_()
        else:  # not the first attack step
            self.delta.requires_grad_()
            if isinstance(self.model, torch.nn.DataParallel):
                self.embeds_init = self.model.module.bert.embeddings.word_embeddings(
                    inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
            else:
                self.embeds_init = self.model.bert.embeddings.word_embeddings(
                    inputs['input_ids'].view(-1, inputs['input_ids'].shape[-1]))
        # Perturb the word embeddings and run the forward pass.
        if isinstance(self.model, torch.nn.DataParallel):
            inputs['inputs_embeds'] = self.delta + self.model.module.bert.embeddings.word_embeddings(inputs['input_ids'])
        else:
            inputs['inputs_embeds'] = self.delta + self.model.bert.embeddings.word_embeddings(inputs['input_ids'])
        ###### inputs['dp_masks'] = self.dp_masks ### outputs, dp_masks = models(**inputs)
        outputs = self.model(**inputs)
        return outputs

    def update(self):
        # Ascent step on delta using the gradient accumulated in the backward pass.
        delta_grad = self.delta.grad.clone().detach()
        self.delta = self.delta.view(delta_grad.size(0) * delta_grad.size(1), delta_grad.size(2), -1)  # 6*512*768
        if self.args.norm_type == "l2":
            denorm = torch.norm(delta_grad.view(delta_grad.size(0) * delta_grad.size(1), -1), dim=1).view(-1, 1, 1)
            denorm = torch.clamp(denorm, min=1e-8)
            self.delta = (self.delta + self.args.adv_lr * delta_grad.view(
                delta_grad.size(0) * delta_grad.size(1), delta_grad.size(2), -1) / denorm).detach()
            if self.args.adv_max_norm > 0:
                delta_norm = torch.norm(self.delta.view(self.delta.size(0), -1).float(), p=2, dim=1).detach()
                exceed_mask = (delta_norm > self.args.adv_max_norm).to(self.embeds_init)
                reweights = (self.args.adv_max_norm / delta_norm * exceed_mask + (1 - exceed_mask)).view(-1, 1, 1)
                self.delta = (self.delta * reweights).detach()
        elif self.args.norm_type == "linf":
            denorm = torch.norm(delta_grad.view(delta_grad.size(0) * delta_grad.size(1), -1), dim=1, p=float("inf")).view(-1, 1, 1)
            denorm = torch.clamp(denorm, min=1e-8)
            self.delta = (self.delta + self.args.adv_lr * delta_grad.view(
                delta_grad.size(0) * delta_grad.size(1), delta_grad.size(2), -1) / denorm).detach()
            if self.args.adv_max_norm > 0:
                self.delta = torch.clamp(self.delta, -self.args.adv_max_norm, self.args.adv_max_norm).detach()
        else:
            print("Norm type {} not specified.".format(self.args.norm_type))
            exit()
        self.delta = self.delta.view(delta_grad.size(0), delta_grad.size(1), delta_grad.size(2), -1)  # 6*512*768 -> 2*3*512*768
Here is the training function:
if args.adv_step > 0 and not args.do_attack == 'freelb':
    args.adv_step = 1
for k in range(args.adv_step):
    if args.do_attack == 'freelb':
        outputs = flb.attack(inputs, is_first_attack=(k == 0))
    else:
        outputs = model(**inputs)
    loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

    ##### share the loss_update ########
    if args.n_gpu > 1:
        loss = loss.mean()  # mean() to average on multi-gpu parallel training
    septr_loss[task_id] += loss.item()
    septr_num[task_id] += 1
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    if args.fp16:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    ########################################
    if args.do_attack == 'freelb':
        flb.update()
Can you first check whether it gives reasonable results under fp32? If so, then it is a problem with fp16 training. As I mentioned before, you have to convert the variables to fp32 before feeding them into any aggregating operation and then switch the result back to fp16, since fp16 has a very limited range. In this example, at least you have to change
denorm = torch.norm(delta_grad.view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1)
into something like
denorm = torch.norm(delta_grad.float().view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1).to(delta_grad)
I didn't handle this part properly in the huggingface implementation, but you should pay attention to that.
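The linf branch of your update() would need the same treatment; roughly like this sketch (not the official implementation; the shapes just follow the 2*3*512*768 comments in your code):

```python
import torch

# Stand-in for the fp16 gradient of delta, shaped (batch, choices, seq_len, hidden) as in the posted code.
delta_grad = torch.randn(2, 3, 512, 768, dtype=torch.float16)

# Aggregate in fp32, clamp away zeros, then cast back to the gradient's dtype.
denorm = torch.norm(delta_grad.float().view(delta_grad.size(0) * delta_grad.size(1), -1),
                    dim=1, p=float("inf")).view(-1, 1, 1)
denorm = torch.clamp(denorm, min=1e-8).to(delta_grad)
```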
Also, this normalization is kind of different from my original approach as defined here.
It runs fine with fp32.
The normalization operation is revised because of the data dimensions of my task.
My input has 3 dimensions, number*choice*512, different from GLUE's number*512.
So for the embeddings and the norm, I changed the dimensions used to compute the norm to keep it consistent with yours:
number*choice*512*embedding_size (2*3*512*768) is reshaped to 6*512*768.
Here,
denorm = torch.norm(delta_grad.float().view(delta_grad.size(0)*delta_grad.size(1), -1), dim=1).view(-1, 1, 1).to(delta_grad)
you added float() and to(delta_grad), which are different from my original version.
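To make the reshape I described concrete, with toy sizes standing in for number*choice*512*768 (illustrative only):

```python
import torch

# Toy stand-in for delta: (num_examples, num_choices, seq_len, hidden) = (2, 3, 4, 5).
delta = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).view(2, 3, 4, 5)

# Flatten examples and choices into one batch dimension, as update() does before taking the norm ...
flat = delta.view(2 * 3, 4, 5)

# ... and restore the original layout afterwards; the round trip is lossless.
restored = flat.view(2, 3, 4, 5)
assert torch.equal(delta, restored)
```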
I have another question about the dropout mask. Your solution keeps the same dropout mask across all the blocks.
Can it be replaced by a fixed-mask dropout operation created in the initialization of the different BERT blocks (such as the encoder)? In this way (see also the variant sketched after the snippet below):
import torch.nn as nn


class LockedDropout(nn.Module):
    def __init__(self, p=0.5):
        self.p = p
        super().__init__()

    def forward(self, x):
        """
        Args:
            x (:class:`torch.FloatTensor` [sequence length, batch size, rnn hidden size]): Input to
                apply dropout too.
        """
        if not self.training or not self.p:
            return x
        x = x.clone()
        mask = x.new_empty(1, x.size(1), x.size(2), requires_grad=False).bernoulli_(1 - self.p)
        mask = mask.div_(1 - self.p)
        mask = mask.expand_as(x)
        return x * mask

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'p=' + str(self.p) + ')'
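If the goal is to reuse one mask across all the adversarial steps of a mini-batch, a variant that samples its mask once and keeps it until explicitly reset might look like this (a hypothetical sketch; FixedMaskDropout and reset() are made-up names, not part of FreeLB or torchnlp):

```python
import torch
import torch.nn as nn

class FixedMaskDropout(nn.Module):
    """Dropout that keeps the same mask until reset() is called,
    e.g. once per mini-batch so every adversarial step sees the same mask."""

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p
        self.mask = None

    def reset(self):
        # Call at the start of each mini-batch to draw a fresh mask on the next forward pass.
        self.mask = None

    def forward(self, x):
        if not self.training or self.p == 0:
            return x
        if self.mask is None or self.mask.shape != x.shape:
            self.mask = torch.bernoulli(torch.full_like(x, 1 - self.p)) / (1 - self.p)
        return x * self.mask
```

reset() would have to be called once per mini-batch, before the first attack step; whether this gives the same behavior as passing dp_masks through the forward pass is something I am not sure about.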
First, thanks for your wonderful work.
Has anyone else met a NaN error during the final epochs of training?
I embedded FreeLB into my network as a plugin (without handling the dropout mask), calling freelb.attack() and freelb.update(). But I ran into a NaN error. The backbone is BERT.
In the early epochs everything works well: the loss converges and the accuracy improves. Once the loss converges to a very small value, apex (fp16) scales the loss down to a very small scale, around 1e-100, and then the NaN error appears.
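One thing I plan to try (just a guess, not from the FreeLB code): since apex only skips the optimizer step when it detects an overflow, not my hand-written delta update, I will guard the update on the delta gradient being finite, roughly like this:

```python
import torch

def delta_grad_is_finite(delta):
    # Hypothetical helper: call before flb.update() when training with fp16,
    # so an overflowed (inf/NaN) delta gradient never enters the ascent step.
    if delta is None or delta.grad is None:
        return False
    return bool(torch.isfinite(delta.grad).all())

# e.g. in the training loop above:
# if args.do_attack == 'freelb' and delta_grad_is_finite(flb.delta):
#     flb.update()
```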