ajabri / videowalk

Repository for "Space-Time Correspondence as a Contrastive Random Walk" (NeurIPS 2020)
http://ajabri.github.io/videowalk
MIT License

Loss becomes "NAN" with Mixed Precision Training #11

Open PkuRainBow opened 3 years ago

PkuRainBow commented 3 years ago

Really impressive work!

We have tried to add DDP and mixed-precision training support on top of your implementation, but we ran into a NaN issue related to this line:

https://github.com/ajabri/videowalk/blob/2ac728d531acd4224177fa5fdbb8a21a7707e1af/code/model.py#L75

Since we use FP16, we modified the above line to:

A[torch.rand_like(A) < self.edgedrop_rate] = -65504

However, we find that the loss soon becomes NaN. It would be great if you could give us any suggestions.

ajabri commented 3 years ago

Hi @PkuRainBow,

Thanks for your interest! Are you sure the NaN is a result of this line; i.e. have you tried training with --dropout 0?

I chose to implement dropout this way (setting the logits to a negative constant) for convenience -- you can also apply dropout on the actual transition matrices. You might also consider using a weaker negative constant -- in practice, something like -1000 should suffice.
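
For concreteness, a minimal sketch of the two alternatives mentioned above (not from the repository; the helper names are made up, and it assumes A holds the pre-softmax affinities while P is the row-stochastic transition matrix, with edgedrop_rate the drop probability):

    import torch

    def edge_dropout_on_logits(A, edgedrop_rate, neg_const=-1000.0):
        # Mask random edges with a moderately negative constant; -65504 is the
        # fp16 lower bound and easily overflows once anything else is subtracted.
        A = A.clone()
        A[torch.rand_like(A) < edgedrop_rate] = neg_const
        return A

    def edge_dropout_on_transitions(P, edgedrop_rate, eps=1e-7):
        # Alternative: zero random entries of the row-stochastic transition
        # matrix P (i.e. after the softmax) and re-normalize each row to sum to 1.
        keep = (torch.rand_like(P) >= edgedrop_rate).float()
        P = P * keep
        return P / P.sum(dim=-1, keepdim=True).clamp_min(eps)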

PkuRainBow commented 3 years ago

@ajabri Thanks for your suggestions; we will report back with our findings later!

I still see NaN errors when I set the dropout to zero with --dropout 0 or set the negative constant to -1000.

To help clarify the cause of the problem, here are my modifications:

    scaler = torch.cuda.amp.GradScaler()

    for step, (video, orig) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        optimizer.zero_grad()

        start_time = time.time()
        video = video.to(device)

        with torch.cuda.amp.autocast():
            output, loss, diagnostics = model(video)
            loss = loss.mean()

        if checkpoint_fn is not None and np.random.random() < 0.005:
            checkpoint_fn()

        # optimizer.zero_grad()
        # loss.backward()
        # optimizer.step()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
        metric_logger.meters['clips/s'].update(video.shape[0] / (time.time() - start_time))
        lr_scheduler.step()
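
(A minimal debugging sketch, not part of the modification above: one way to localize the first non-finite loss in this loop, reusing the same variable names, is to check the loss before stepping and enable autograd anomaly detection.)

    torch.autograd.set_detect_anomaly(True)  # report the backward op that first produces NaN/Inf

    with torch.cuda.amp.autocast():
        output, loss, diagnostics = model(video)
        loss = loss.mean()

    if not torch.isfinite(loss):
        # Skip the update instead of propagating NaN into the weights.
        print('non-finite loss at step', step, diagnostics)
        optimizer.zero_grad()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()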
ajabri commented 3 years ago

Were you able to address this issue?

pansanity666 commented 2 years ago

I also encountered this issue when using fp16. In my experiments, after reducing EPS from 1e-20 to 1e-7 the NaN in the loss disappears, but the performance is extremely poor. I think the log operation when calculating the logits:

logits = torch.log(A+EPS).flatten(0, -2)

is redundant, since the nn.CrossEntropyLoss() API already computes the log internally. So why are the EPS and the extra log operation necessary?

ajabri commented 2 years ago

Hi @pansanity666

You're right that it's redundant, because nn.CrossEntropyLoss computes log-softmax and then NLL. I think I stored log-probs here because I had other losses before. You can just compute the trace of torch.log(A+EPS) to get the loss.
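
A minimal sketch of that trace formulation, assuming A is a batch of row-stochastic matrices of shape (B, N, N) and EPS is the small constant from model.py (the random A below is only a stand-in):

    import torch

    B, N, EPS = 4, 16, 1e-7
    A = torch.softmax(torch.randn(B, N, N), dim=-1)  # stand-in for the walk's transition matrix

    # Negative mean of the diagonal of log(A + EPS), i.e. the (normalized) trace.
    loss = -torch.diagonal(torch.log(A + EPS), dim1=-2, dim2=-1).mean()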

pansanity666 commented 2 years ago

> Hi @pansanity666
>
> You're right that it's redundant, because nn.CrossEntropyLoss computes log-softmax and then NLL. I think I stored log-probs here because I had other losses before. You can just compute the trace of torch.log(A+EPS) to get the loss.

Thank you for your reply. I tried removing the redundant log operation, i.e., replacing logits = torch.log(A+EPS).flatten(0, -2) with logits = A.flatten(0, -2), and I find it gives similar performance (or slightly better); moreover, it works well with fp16.

ajabri commented 2 years ago

Hi @pansanity666

I believe nn.CrossEntropyLoss expects logits or log-probabilities, so you will have to take the log. You can do the following instead:

logits = torch.log(A+EPS).flatten(0,-2)
loss = -torch.einsum('bnm,nm->bn', logits, torch.eye(A.shape[-1])).mean()

or consider skipping the log:

loss = -torch.einsum('bnm,nm->bn', A.flatten(0,-2), torch.eye(A.shape[-1])).mean()
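
A self-contained sketch of both options, assuming A has shape (B, N, N) with rows summing to 1; the flatten step is dropped here so the einsum sees a 3-D tensor, and the random A is only a stand-in:

    import torch

    B, N, EPS = 4, 16, 1e-7
    A = torch.softmax(torch.randn(B, N, N), dim=-1)
    I = torch.eye(N, device=A.device)

    # With the log: negative mean diagonal log-probability (cross-entropy style).
    loss_log = -torch.einsum('bnm,nm->bn', torch.log(A + EPS), I).mean()

    # Without the log: negative mean diagonal probability -- a different,
    # fp16-friendlier objective, not a true cross-entropy.
    loss_prob = -torch.einsum('bnm,nm->bn', A, I).mean()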