PkuRainBow opened 3 years ago
Hi @PkuRainBow,
Thanks for your interest! Are you sure the NaN is a result of this line; i.e., have you tried training with --dropout 0?
I chose to implement dropout this way (setting the logits to a negative constant) for convenience -- you can also apply dropout on the actual transition matrices. You might also consider using a weaker negative constant -- in practice, something like -1000 should suffice.
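Under mixed precision, the choice of constant matters more than it looks: fp16's finite range ends at ±65504, so a constant like -1e20 silently becomes -inf when cast down, while -1000 is safely representable. A quick check using NumPy's float16 as a stand-in for torch's fp16:

```python
import numpy as np

with np.errstate(over='ignore'):
    print(np.float16(-1e20))   # -inf: overflows fp16 (max magnitude 65504)
print(np.float16(-1000.0))     # -1000.0: representable, safe to use
```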
@ajabri Thanks for your suggestions! We will get back to you with an update later.
I still encounter NaN errors when I set the dropout to zero with --dropout 0 or set the negative constant to -1000.
To help pinpoint the cause, here are my modifications:
```python
scaler = torch.cuda.amp.GradScaler()

for step, (video, orig) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
    optimizer.zero_grad()
    start_time = time.time()
    video = video.to(device)

    with torch.cuda.amp.autocast():
        output, loss, diagnostics = model(video)
        loss = loss.mean()

    if checkpoint_fn is not None and np.random.random() < 0.005:
        checkpoint_fn()

    # optimizer.zero_grad()
    # loss.backward()
    # optimizer.step()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
    metric_logger.meters['clips/s'].update(video.shape[0] / (time.time() - start_time))
    lr_scheduler.step()
```
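If the NaN appears only occasionally, one defensive pattern (a sketch, not part of the repo) is to check the loss before stepping; GradScaler already skips optimizer steps when gradients overflow, but an already-NaN loss still calls for skipping backward entirely:

```python
import torch

# Hypothetical helper (not from the repo): skip the update when the loss
# is already non-finite, so one bad fp16 batch does not poison training.
def safe_step(scaler, optimizer, loss):
    if not torch.isfinite(loss):
        optimizer.zero_grad(set_to_none=True)
        return False          # skipped: loss was inf/NaN
    scaler.scale(loss).backward()
    scaler.step(optimizer)    # internally skips if grads are inf/NaN
    scaler.update()
    return True
```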
Were you able to address this issue?
I also encountered this issue when using fp16. In my experiments, after I reduced the EPS from 1e-20 to 1e-7, the NaN in the loss disappeared, but the performance was extremely poor. I think the log operation when calculating the logits:

```python
logits = torch.log(A + EPS).flatten(0, -2)
```

is redundant, since nn.CrossEntropyLoss() automatically applies the log. So why are the EPS and the extra log operation necessary?
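For what it's worth, the EPS value itself is a likely fp16 culprit: 1e-20 is far below float16's smallest subnormal (about 6e-8), so A + EPS is exactly A in half precision, and log(0) = -inf wherever A has a zero entry; the -inf then becomes NaN in the backward pass. A quick demonstration using NumPy's float16 as a stand-in for torch's fp16:

```python
import numpy as np

eps16 = np.float16(1e-20)
print(eps16)                    # 0.0 -- underflows; fp16's min subnormal is ~6e-8
with np.errstate(divide='ignore'):
    print(np.log(np.float16(0.0) + eps16))   # -inf, which becomes NaN in gradients
print(np.float16(1e-7))         # still representable, so log stays finite
```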
Hi @pansanity666,
You're right that it's redundant, because nn.CrossEntropyLoss computes log-softmax and then NLL. I think I stored log-probabilities here because I had other losses before. You can just compute the trace of torch.log(A + EPS) to get the loss.
Thank you for your reply.
I tried removing the redundant log operation, i.e., replacing

```python
logits = torch.log(A + EPS).flatten(0, -2)
```

with

```python
logits = A.flatten(0, -2)
```

and I found it gives similar (or slightly better) performance; moreover, it works well with fp16.
Hi @pansanity666,
I believe nn.CrossEntropyLoss expects logits or log-probabilities, so you will have to take the log.
You can do the following instead (note that the flatten must be dropped for the 'bnm' einsum to see a 3D tensor):

```python
logits = torch.log(A + EPS)
loss = -torch.einsum('bnm,nm->bn', logits, torch.eye(A.shape[-1])).mean()
```

or consider skipping the log:

```python
loss = -torch.einsum('bnm,nm->bn', A, torch.eye(A.shape[-1])).mean()
```
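Since the einsum against the identity just extracts A's diagonal, an equivalent (and arguably clearer) formulation uses torch.diagonal; this is a sketch assuming A has shape (B, N, N) with rows summing to one, and an fp16-representable eps:

```python
import torch

def cycle_loss(A, eps=1e-7):
    # Each node should walk back to itself, so the target is the identity:
    # the loss is the negative mean log-probability of A's diagonal.
    # eps guards log(0) and, unlike 1e-20, survives fp16.
    diag = torch.diagonal(A, dim1=-2, dim2=-1)  # (B, N)
    return -torch.log(diag + eps).mean()
```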
Really impressive work!
We have tried to support DDP and mixed-precision training based on your implementation, but we are hitting a NaN issue related to:
https://github.com/ajabri/videowalk/blob/2ac728d531acd4224177fa5fdbb8a21a7707e1af/code/model.py#L75
Since we use FP16, we modified the above code as follows:

```python
A[torch.rand_like(A) < self.edgedrop_rate] = -65504
```

However, we find that the loss becomes NaN soon after. It would be great if you could give us any suggestions.
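One guess at the failure mode: -65504 is exactly fp16's most negative finite value, so any subsequent addition or accumulation overflows straight to -inf, and -inf becomes NaN once it reaches the loss. A dtype-aware constant with headroom avoids sitting on the edge of the representable range; the function below is an illustrative sketch, not the repo's code:

```python
import torch

def edge_dropout(A_logits, drop_rate):
    # -1e4 is enough to zero a softmax entry, yet leaves plenty of
    # headroom above fp16's -65504 limit for later additions.
    neg = -1e4 if A_logits.dtype == torch.float16 else -1e20
    mask = torch.rand_like(A_logits) < drop_rate
    return A_logits.masked_fill(mask, neg)
```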