iamkucuk opened this issue 3 years ago
Hi @iamkucuk, this is somewhat expected, and it is not obvious how to solve it. As you may know, in DDP each subprocess sees a different split of the data, and conditionally skipping the training step causes the processes to get out of sync. I'm not exactly sure why training freezes, but it must have to do with the processes being out of sync, or with some processes waiting to sync gradients that were never computed in the others.
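A toy model (threads, not DDP itself) of why one rank skipping a step can hang the whole job: each worker thread calls a `threading.Barrier` once per step as a stand-in for DDP's gradient all-reduce. `run_ranks` and its `skip` predicate are hypothetical names for illustration. If one rank skips a step, its later calls pair up with the other rank's earlier ones, and the other rank eventually waits alone; the barrier timeout plays the role of the freeze.

```python
import threading
from typing import Callable, List


def run_ranks(world_size: int, num_steps: int,
              skip: Callable[[int, int], bool], timeout: float = 0.5) -> List[int]:
    """Return the ranks that got stuck waiting for a collective that never completed."""
    # Every rank must join each barrier "generation", like every rank must join each all-reduce.
    barrier = threading.Barrier(world_size, timeout=timeout)
    stuck: List[int] = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        for step in range(num_steps):
            if skip(rank, step):
                continue  # this rank skips its gradient sync for `step`
            try:
                barrier.wait()  # stand-in for DDP's gradient all-reduce
            except threading.BrokenBarrierError:
                # The timeout stands in for what, in DDP, shows up as a hang.
                with lock:
                    stuck.append(rank)
                return

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sorted(stuck)
```

For example, `run_ranks(2, 3, lambda rank, step: rank == 1 and step == 0)` returns `[0]`: rank 1's step-1 sync silently pairs with rank 0's step-0 sync, and rank 0 is left waiting alone at its last step.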
@SeanNaren do you have an idea what we could do when the user wants to skip a step in DDP?
@iamkucuk In the meantime, I would investigate why the loss in your output is NaN in the first place; I am sure it can be avoided entirely. Regarding the second part of your condition (random), what purpose does it serve? Can't you just reduce the data by 5% at the beginning?
Hi @awaelchli, I am not sure yet, but it may be an exploding-gradient issue where a single batch generates very large gradients. It happens rarely, and the model learns perfectly when it doesn't. I tried gradient clipping, which seemingly solves the problem but seriously slows down learning in my case. Another approach I tried is gradient accumulation, which I think dilutes the effect of a single batch producing low-quality gradients, and it did reduce the NaN loss errors significantly. However, the problem still persists.
Another approach I tried was setting the loss to torch.tensor(0), thinking it would keep the model from being updated on that batch. However, the new tensor is not connected to the computation graph, so the graph is lost.
The random.random() < .05 condition only serves to reproduce the NaN loss error more often, since it happens very rarely. It has nothing to do with my training procedure.
Hey @awaelchli,
I wonder if we could do the following:
class LightningModule:
    # Introduce this parameter so we don't force synchronisation on every step for all users.
    might_training_step_return_none = False

# call training_step: process_i -> returns None, process_j -> returns loss
if might_training_step_return_none:
    should_skip = self.trainer.accelerator_backend.all_reduce(output is None) > 0
    if should_skip:
        return
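A plain-Python sketch of what the proposed all-reduce buys, with no Lightning or torch involved: `outputs[rank][step]` holds what a hypothetical `training_step` returned on each rank, and summing the `is None` flags stands in for `all_reduce(output is None)`. `count_optimizer_steps` is an illustrative name only. Without the sync, ranks end up with different numbers of optimizer steps (and therefore different numbers of gradient all-reduces); with it, they always agree.

```python
from typing import List, Optional


def count_optimizer_steps(outputs: List[List[Optional[float]]], sync: bool) -> List[int]:
    """outputs[rank][step] is what training_step returned on that rank (None = skip)."""
    world_size = len(outputs)
    num_steps = len(outputs[0])
    steps_taken = [0] * world_size
    for step in range(num_steps):
        # Stand-in for an all-reduce (SUM) of the `output is None` flags:
        # every rank sees the same count.
        num_skipping = sum(outputs[rank][step] is None for rank in range(world_size))
        for rank in range(world_size):
            if sync:
                skip = num_skipping > 0  # "safe side": skip on every rank
            else:
                skip = outputs[rank][step] is None  # each rank decides alone
            if not skip:
                # In DDP, each such step implies one round of gradient all-reduces.
                steps_taken[rank] += 1
    return steps_taken
```

With `outputs = [[1.0, None, 0.5], [0.9, 0.8, 0.7]]`, the unsynced version gives `[2, 3]` (rank 1 waits for an all-reduce rank 0 never joins), while the synced version gives `[2, 2]`.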
Yes, this is totally feasible. With this solution we are on the "safe side" and always skip in all processes when at least one says skip. Actually, we already have this logic for early stopping: see the accelerator, which has a reduce for early stopping.
The other option is to skip only if all processes say skip. I can't come up with a use case, but surely someone will. With that in mind, I would aim for a parameter like:
# sum or prod
# sum > 0 means skip
# prod > 0 means skip
# int: process with this index decides (broadcast)
skip_decision: Union[str, int] = "sum"
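The `skip_decision` options can be expressed as a small pure-Python decision rule over the per-rank skip flags. In a real job the flags would presumably be combined with `torch.distributed.all_reduce` (`ReduceOp.SUM` or `ReduceOp.PRODUCT`) or with `torch.distributed.broadcast` from the deciding rank; here the list of flags simply stands in for what every rank would see after that collective. `decide_skip` is a hypothetical helper, not a Lightning API.

```python
import math
from typing import Sequence, Union


def decide_skip(flags: Sequence[bool], skip_decision: Union[str, int] = "sum") -> bool:
    """Return the skip decision that every rank would agree on.

    flags[i] is True if rank i wants to skip the current step.
    """
    as_ints = [int(f) for f in flags]
    if isinstance(skip_decision, str):
        if skip_decision == "sum":
            # SUM reduction: skip if at least one rank wants to skip
            return sum(as_ints) > 0
        if skip_decision == "prod":
            # PRODUCT reduction: skip only if every rank wants to skip
            return math.prod(as_ints) > 0
        raise ValueError(f"unknown skip_decision: {skip_decision!r}")
    # Broadcast from rank `skip_decision`: that rank alone decides for everyone
    return bool(flags[skip_decision])
```

With flags `[False, True, False]`, `"sum"` skips, `"prod"` does not, and an integer index defers to that rank's own flag.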
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
Would love to know when this feature is on the master branch! Deepspeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.
I think this is already resolved in #5359 and merged into 1.1.x. If so, it can be closed.
(Though it might still be unresolved when using DDP-Sharded or if you accidentally log the NaN loss, causing a freeze)
@AAnoosheh No, #5359 was closed without being merged. We still need to work on it and figure out a solution.
Hi, I'm in a similar situation. My batches are formed from the output of an object detector, so sometimes a batch will effectively be of size zero (I can't think of a good way to explain this, but trust that it makes sense). In this case, I would like to return None from training_step, or at least return some kind of zero loss tensor with zero gradients. If it's not easy to return None, is there some way to artificially construct a zero tensor with all the appropriate gradients present so that the DDP sync will work?
Hi, is there a solution to this issue?
No, @AsaphLightricks. Returning None in training_step with DDP is currently unsupported.
Also looking for a patch for this! (My use case is pretty much exactly what @kazimpal87 described above)
I also need this feature. Is there a known workaround?
Hi, any updates on this?
Seconding this; it is basically a must-have feature. Even skipping the update entirely across all workers when just one returns None would be fine for me.
Right now the training freezes, so Lightning + DDP cannot be used for a lot of tasks, e.g. automatic speech recognition (ASR), where this issue is common due to the loss used combined with mixed precision.
Can't it be done here, basically skipping the reduce altogether and returning None? (Not very familiar with Lightning, sorry.) https://github.com/Lightning-AI/lightning/blob/7268670d1aec7d40241962c5bc81b0b871de0a57/src/lightning/pytorch/strategies/ddp.py#LL314C5-L331C1
Hi! I'll give this a look :)
Great, thanks! BTW, if someone has this problem, for now it can be solved with manual optimization.
On my side, the training freezes not on the iteration where None is returned but on the one right after, at this call in pytorch_lightning/loops/training_epoch_loop.py:
batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
The iteration that returns None actually seems to go through without a problem and even updates the progress bar.
I have a fix, but it lives inside my training loop (it is similar to @tchaton's suggestion above):
import logging

import torch
import torch.distributed as torch_dist

log = logging.getLogger(__name__)

# Method of the LightningModule (assumes self.nan_countdown is initialized, e.g. to 1, in __init__)
def _reduce_loss(self, utt_id, batch_id, loss, reduction="mean"):
    assert loss.shape[0] == len(utt_id), "loss must be reduced to batch dimension !"
    mask_nan_inf = torch.logical_or(torch.isnan(loss), ~torch.isfinite(loss))
    if torch.any(mask_nan_inf):
        where_invalid = torch.where(mask_nan_inf)[0]
        for indx in range(where_invalid.shape[0]):
            inv_indx = where_invalid[indx].item()
            log.info(
                f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, for utt_id {utt_id[inv_indx]}"
            )
        # if any is invalid then we must flag this to all processes
        flag_skip = torch.ones((), device=loss.device, dtype=torch.bool)
    else:
        flag_skip = torch.zeros((), device=loss.device, dtype=torch.bool)
    # sub-optimal but will do,
    # till they fix it in https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
    world_size = torch_dist.get_world_size()
    torch_dist.barrier()
    # now gather
    result = [torch.zeros_like(flag_skip) for _ in range(world_size)]
    torch_dist.all_gather(result, flag_skip)
    any_invalid = torch.sum(torch.stack(result)).bool().item()
    if any_invalid:
        if self.nan_countdown >= 100:
            raise RuntimeError(
                "Too many NaNs loss iterations encountered, stopping !"
            )
        self.nan_countdown += 1
        return None
    else:
        self.nan_countdown = 1
        return loss.mean() if reduction == "mean" else loss.sum()
Basically, I gather a flag across all DDP workers; if any worker sets the flag, all workers must return None. When all of them return None, there is no more freezing. But it would be neat if this were handled inside Lightning; I feel I am just adding unnecessary synchronization here.
I am sure that for someone more familiar with Lightning's background magic, it must be easy to add something similar in the right place.
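The flag-and-budget part of `_reduce_loss` above can be isolated from torch for clarity. This is a sketch assuming `gathered_flags` is what every rank received from the `all_gather` (one bool per rank); `NonFiniteLossGuard` and its parameter names are hypothetical. `math.isfinite`, like `torch.isfinite`, is already False for NaN, so one check covers both NaN and inf. The budget here (raise once more than `max_consecutive_skips` steps in a row are skipped) approximates the `nan_countdown` counter rather than reproducing it exactly.

```python
import math
from typing import Optional, Sequence


class NonFiniteLossGuard:
    """Hypothetical helper: decide whether to skip a step from gathered per-rank flags."""

    def __init__(self, max_consecutive_skips: int = 100) -> None:
        self.max_consecutive_skips = max_consecutive_skips
        self.consecutive_skips = 0

    @staticmethod
    def local_flag(per_sample_losses: Sequence[float]) -> bool:
        # math.isfinite is False for NaN and for +/-inf, so one check covers both.
        return not all(math.isfinite(x) for x in per_sample_losses)

    def reduce(self, per_sample_losses: Sequence[float],
               gathered_flags: Sequence[bool]) -> Optional[float]:
        if any(gathered_flags):
            # Some rank saw a non-finite loss, so this rank skips too (returns None).
            self.consecutive_skips += 1
            if self.consecutive_skips > self.max_consecutive_skips:
                raise RuntimeError("Too many consecutive non-finite loss iterations, stopping!")
            return None
        self.consecutive_skips = 0  # a good step resets the budget
        return sum(per_sample_losses) / len(per_sample_losses)
```

Because every rank feeds the same `gathered_flags` into `reduce`, all ranks return None on the same steps, which is the property that prevents the freeze.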
@popcornell, thanks for sharing! This logic seems to work for OOM errors occurring in the forward pass. However, when I tried similar logic in Lightning's backward hook, I ran into the dreaded DDP freezing issue, suggesting that my DDP ranks still fall out of sync. My backward code is as follows. Does anyone here have an idea what might be causing the ranks to get out of sync with each other? Notably, my code never logs the warning "Ran out of memory in the backward pass. Skipping batches for all ranks.", which suggests that at least one of the ranks never reaches the barrier (i.e., torch_dist.barrier()) once all the others have.
import logging
from typing import Any

import torch
import torch.distributed as torch_dist

log = logging.getLogger(__name__)

# Method of the LightningModule
def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Any):
    """Overrides Lightning's `backward` step to add an out-of-memory (OOM) check."""
    # by default, do not skip the current batch
    skip_flag = torch.zeros(
        (), device=self.device, dtype=torch.bool, requires_grad=False
    )  # NOTE: for skipping batches in a multi-device setting
    try:
        loss.backward(*args, **kwargs, retain_graph=False)
    except RuntimeError as e:
        skip_flag = torch.ones((), device=self.device, dtype=torch.bool, requires_grad=False)
        if "out of memory" in str(e):
            log.warning(
                f"Ran out of memory in the backward pass, where `torch_dist.is_initialized` is {torch_dist.is_initialized()}. Skipping batch due to: {e}"
            )
            if not torch_dist.is_initialized():
                # NOTE: for skipping batches in a single-device setting
                del loss  # delete the computation graph
                for p in self.net.parameters():
                    if p.grad is not None:
                        del p.grad
    # NOTE: for skipping batches in a multi-device setting
    # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
    if torch_dist.is_initialized():
        # if any rank skips a batch, then all other ranks need to skip
        # their batches as well so DDP can properly keep all ranks synced
        world_size = torch_dist.get_world_size()
        torch_dist.barrier()
        result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
        torch_dist.all_gather(result, skip_flag)
        any_skipped = torch.sum(torch.stack(result)).bool().item()
        if any_skipped:
            del loss  # delete the computation graph
            for p in self.net.parameters():
                if p.grad is not None:
                    del p.grad
            log.warning(
                "Ran out of memory in the backward pass. Skipping batches for all ranks."
            )
Following up on my previous comment, https://github.com/pytorch/pytorch/issues/18853#issuecomment-698386652 seems to discuss a related issue. In my case, the OOM errors may cause loss.backward to complete for some variables of the overall autograd graph but not all of them, since some never finish their loss.backward call. Put another way, is there a way to manually tell DDP to mark certain autograd variables as "having completed their backward pass" even if they actually haven't? That way, DDP ranks would never fall out of sync with each other.
Your code probably hangs because one process gets the exception and the other does not, so the one that did not raise hangs at the barrier while watching the other die (sad). When you catch the exception and run some logic, you need to synchronize that decision: make it so that either all of them raise or none of them do, but a mix is not allowed.
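A toy model of that advice, using threads instead of processes: each rank catches its own (simulated) OOM instead of re-raising, every rank then reaches the same sync point, and all ranks apply the same decision. `run_with_synced_skip` is illustrative only. It assumes the failure happens before any collective has started, as with an OOM in the forward pass; an OOM inside `loss.backward()` under DDP is different, because the healthy ranks may already be blocked in DDP's own gradient all-reduce and never reach this sync point.

```python
import threading
from typing import List, Set


def run_with_synced_skip(world_size: int, failing_ranks: Set[int]) -> List[str]:
    """Return each rank's decision ("skip" or "step") for one training step."""
    flags = [False] * world_size  # one "I failed" flag per rank
    decisions = [""] * world_size
    # Stand-in for all_gather: nobody reads the flags until every rank has written its own.
    barrier = threading.Barrier(world_size, timeout=2.0)

    def worker(rank: int) -> None:
        try:
            if rank in failing_ranks:
                raise RuntimeError("CUDA out of memory (simulated)")
        except RuntimeError:
            flags[rank] = True  # record the failure locally; do NOT re-raise before syncing
        barrier.wait()  # every rank reaches this point, failed or not
        decisions[rank] = "skip" if any(flags) else "step"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return decisions
```

A single failing rank makes every rank skip, so the number of subsequent collectives stays identical across ranks.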
@awaelchli, thanks for your suggestion. I agree: if a DDP rank raises an exception (e.g., e in my code above), this kind of issue will definitely occur. However, what confuses me is that my warning log.warning(f"Ran out of memory in the backward pass. Skipping batch due to: {e}") gets printed to the respective log file right before my entire training job freezes. This implies that the DDP rank that OOM'd did go down the if branch of my exception-handling code, which should not raise the underlying exception e. To check this, I am now running a test of the same scenario with the else: raise e branch of the exception-handling logic removed, to see whether some DDP ranks were indeed raising the actual exception when they shouldn't have been. I'll report back with my findings.
@awaelchli, sadly, removing the else: raise e branch from my exception-handling logic above does not resolve the DDP freezing issue I am facing. What else might be causing one of the ranks to die before it reaches the torch_dist.barrier() call? I can confirm that exactly one warning is logged before my training job freezes completely across all ranks:
Ran out of memory in the backward pass, where `torch_dist.is_initialized` is True. Skipping batch due to: CUDA out of memory. Tried to allocate 5.18 GiB (GPU 0; 79.15 GiB total capacity; 72.98 GiB already allocated; 139.25 MiB free; 78.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.
Since torch_dist.is_initialized is confirmed to be True here, any DDP rank that logs this message should not return None, which should keep that rank from "dying early".
Sadly, it looks like I may have found a culprit for this issue: wandb (using the latest version, 0.15.12; reference: https://github.com/pytorch/xla/issues/1562#issuecomment-579765003).
After temporarily disabling my WandbLogger and rerunning my training script from the latest checkpoint (shortly after which my script would normally OOM during the loss.backward() call), I no longer see the original OOM error that would otherwise freeze all my DDP ranks. I don't believe I am having wandb watch my model's computational graph in any way, unless similar behavior is enabled by default in Lightning's WandbLogger.
Has anyone else experienced this issue and found a way around it (besides perhaps setting offline=True for WandbLogger and manually syncing the local logs to the remote later on)? @sydholl, has the wandb team seen any issues like this before (https://github.com/wandb/wandb/issues/2091)?
🐛 Bug
Returning None from training_step with multi-GPU DDP training freezes the training without raising an exception.
To Reproduce
Start multi-GPU training with a training_step function that returns None.
Example training_step function:
Example trainer:
Expected behavior
Training should continue, skipping the current batch, as pointed out here.
Environment
No specific environment is needed to reproduce this bug.
Additional context
This issue was mentioned here: #4956 but not with specifics.
Note: While this issue is being investigated, help with a workaround would be great!