Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Returning None from training_step with multi GPU DDP training #5243

Open iamkucuk opened 3 years ago

iamkucuk commented 3 years ago

🐛 Bug

Returning None from training_step during multi-GPU DDP training freezes training without raising an exception.

To Reproduce

Start multi-GPU training with a training_step that returns None.

Example training_step function:

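    import random

    import torch

    def training_step(self, batch, batch_idx):
        data, target = batch
        model_outputs = self(data)  # forward pass on the input batch
        loss = calc_loss(model_outputs, target)  # user-defined loss function

        # skip this batch on a NaN loss (the random condition only makes the bug easier to trigger)
        if torch.isnan(loss) or random.random() < .05:
            return None

        return loss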

Example trainer:

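    from pytorch_lightning import Trainer

    # argument names of the Lightning version used at the time; newer releases
    # use Trainer(accelerator="gpu", devices=2, strategy="ddp") instead
    trainer = Trainer(
        gpus=2,
        distributed_backend="ddp",
    )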

Expected behavior

Training should continue, skipping the current batch, as described here.

Environment

No specific environment is needed to reproduce this bug.

Additional context

This issue was mentioned in #4956, but without specifics.

Note: While this issue is being investigated, help with a workaround would be great!

awaelchli commented 3 years ago

Hi @iamkucuk, this is somewhat expected, and it is not obvious how to solve it. As you may know, in DDP each subprocess sees a different split of the data, and conditionally skipping the training step causes the processes to get out of sync. I'm not exactly sure why training freezes, but it must have to do with the processes being out of sync, or with some processes waiting to sync gradients that were never computed in the others.
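A toy model (pure Python, not DDP itself) may make the hang easier to picture. Each simulated rank stands in for the gradient all-reduce by waiting on a shared `threading.Barrier`; a rank whose training_step returned None skips backward and never arrives, so the ranks that did not skip are left waiting. The half-second timeout exists only so the toy terminates: real collectives typically block much longer (up to the process-group timeout, if one is enforced), so a real job looks frozen instead of failing. `run_step` is a hypothetical name.

```python
import threading
from typing import Dict, Set


def run_step(num_ranks: int, skipping_ranks: Set[int], timeout: float = 0.5) -> Dict[int, str]:
    """Toy simulation of one training step; returns what happened on each rank."""
    barrier = threading.Barrier(num_ranks)  # stands in for the gradient all-reduce
    outcome: Dict[int, str] = {}

    def rank_fn(rank: int) -> None:
        if rank in skipping_ranks:
            outcome[rank] = "skipped"  # returned None: no backward, so no all-reduce
            return
        try:
            barrier.wait(timeout=timeout)  # proceeds only once all ranks have arrived
            outcome[rank] = "synced"
        except threading.BrokenBarrierError:
            outcome[rank] = "stuck"  # at least one peer never arrived

    threads = [threading.Thread(target=rank_fn, args=(r,)) for r in range(num_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return outcome


print(run_step(2, skipping_ranks=set()))  # both ranks reach the "all-reduce"
print(run_step(2, skipping_ranks={1}))    # rank 0 waits for rank 1 and times out
```

The last case in the tests, where every rank skips, is the situation the rest of this thread is aiming for: nobody waits, because nobody calls the collective.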

@SeanNaren do you have an idea what we could do when the user wants to skip a step in DDP?

@iamkucuk In the meantime, I would investigate why the loss is NaN in the first place; I am sure it can be avoided entirely. Regarding the second part of your condition (random), what purpose does it serve? Can you not just reduce the data by 5% at the beginning?

iamkucuk commented 3 years ago

Hi @awaelchli, I am not sure yet, but it may be an exploding-gradient issue where a single batch produces very large gradients. It happens rarely, and the model learns perfectly when it doesn't. I tried clipping gradients, which seriously slows down learning in my case but seemingly solves the problem. Another approach I tried is gradient accumulation, which I think dampens the effect of a single batch producing poor gradients, and it did reduce the NaN loss errors significantly. However, the problem still persists.

Another approach I tried was setting the loss to torch.tensor(0), hoping the model would not be updated for that batch. However, that tensor is disconnected from the computation graph.
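A quick check of why both naive zero losses fall short, using a made-up parameter vector `w` in place of the real model: `torch.tensor(0.0)` creates a fresh leaf with no `grad_fn`, so `backward()` on it raises a RuntimeError, while `loss * 0.0` stays connected to the graph but keeps the NaN, because NaN * 0 is still NaN.

```python
import torch

w = torch.ones(3, requires_grad=True)
bad_loss = (w * float("nan")).sum()  # a NaN loss that is still attached to w

detached_zero = torch.tensor(0.0)  # new leaf tensor: no grad_fn, requires_grad=False
scaled_zero = bad_loss * 0.0       # still attached to w, but NaN * 0 is NaN

try:
    detached_zero.backward()
except RuntimeError as err:
    print("backward on torch.tensor(0.0) fails:", err)

scaled_zero.backward()
print(scaled_zero.item(), w.grad)  # the NaN reaches both the loss and the gradients
```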

The random.random() < .05 condition only serves to reproduce the NaN loss more often, since it occurs very rarely. It has nothing to do with my actual training procedure.

tchaton commented 3 years ago

Hey @awaelchli,

I wonder if we could do the following:

    class LightningModule:
        # introduce this attribute so we don't force synchronisation on every step for all users
        might_training_step_return_none = False

    # training_step is called on every process:
    # process_i -> returns None, process_j -> returns a loss

    if might_training_step_return_none:
        # reduce the per-process "output is None" flags; > 0 means at least one process wants to skip
        should_skip = self.trainer.accelerator_backend.all_reduce(output is None) > 0
        if should_skip:
            return

awaelchli commented 3 years ago

Yes, this is totally feasible. With this solution we will be on the "safe side" and always skip in all processes when at least one says skip. Actually, we already have this logic for early stopping: see the accelerator, which has a reduce for early stopping.

The other option is to skip only if all processes say skip. I can't come up with a use case, but surely someone will. With that in mind, I would aim for a parameter like:

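    from typing import Union

    # "sum": skip if the sum of the skip flags is > 0, i.e. any process says skip
    # "prod": skip if the product of the skip flags is > 0, i.e. every process says skip
    # int: the process with this index decides (broadcast)
    skip_decision: Union[str, int] = "sum"
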
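The three `skip_decision` modes can be compared without any GPUs. In this pure-Python sketch (`decide_skip` is a hypothetical helper, not a Lightning API), a list of per-rank flags stands in for what an all_reduce with SUM or PRODUCT, or a broadcast from one rank, would return on every process.

```python
from math import prod
from typing import List, Union


def decide_skip(skip_flags: List[bool], skip_decision: Union[str, int] = "sum") -> bool:
    """Return the decision every rank would agree on for one step.

    skip_flags[r] is True if rank r wants to skip. Real DDP ranks only hold their
    own flag; the list stands in for the result of the collective.
    """
    flags = [int(flag) for flag in skip_flags]
    if skip_decision == "sum":  # all_reduce(SUM) > 0: any rank says skip
        return sum(flags) > 0
    if skip_decision == "prod":  # all_reduce(PRODUCT) > 0: every rank says skip
        return prod(flags) > 0
    if isinstance(skip_decision, int):  # rank `skip_decision` broadcasts its flag
        return bool(flags[skip_decision])
    raise ValueError(f"unknown skip_decision: {skip_decision!r}")


flags = [False, True, False]  # only rank 1 saw a bad loss
print(decide_skip(flags, "sum"), decide_skip(flags, "prod"), decide_skip(flags, 0))
# -> True False False
```

Whatever the mode, the important property is that every rank computes the same answer from the same collective result, so all ranks either run backward or skip it together.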
stale[bot] commented 3 years ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

aced125 commented 3 years ago

Would love to know when this feature is on the master branch! Deepspeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.

AAnoosheh commented 3 years ago

> Would love to know when this feature is on the master branch! Deepspeed training seems to give infs for some reason (even on the minGPT example), so it would be cool to skip these steps.

I think this is already resolved in #5359 and merged into 1.1.x. If so, it can be closed.
(Though it might still be unresolved when using DDP-Sharded, or if you accidentally log the NaN loss, causing a freeze.)

awaelchli commented 3 years ago

@AAnoosheh no, #5359 was closed before it could be merged. We still need to work on this and figure out a solution.

kazimpal87 commented 2 years ago

Hi, I'm in a similar situation. My batches are formed from the output of an object detector, so sometimes a batch is effectively empty (I can't think of a good way to explain this, but trust me that it makes sense). In that case I would like to return None from training_step, or at least return some kind of zero loss tensor with zero gradients. If it's not easy to return None, is there some way to construct a zero tensor with all the appropriate gradients present so that the DDP sync still works?
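One way to get such a tensor, sketched under the assumption that every trainable parameter should take part in the step: multiply each parameter by zero and sum. The result is 0, it is connected to every parameter, and `backward()` gives all-zero gradients, so DDP's gradient all-reduce should still run on every rank as usual. The bad loss itself is deliberately not used, because NaN * 0 is still NaN. `zero_loss_with_graph` is a hypothetical helper name.

```python
import torch
from torch import nn


def zero_loss_with_graph(module: nn.Module) -> torch.Tensor:
    """Scalar 0.0 that depends on every trainable parameter of `module`."""
    params = [p for p in module.parameters() if p.requires_grad]
    # (p * 0.0).sum() is 0 for finite p, and its gradient w.r.t. p is all zeros
    return torch.stack([(p * 0.0).sum() for p in params]).sum()


model = nn.Linear(4, 1)
loss = zero_loss_with_graph(model)
loss.backward()
print(loss.item(), [p.grad.abs().sum().item() for p in model.parameters()])
```

In a training_step this might look like `if not torch.isfinite(loss): loss = zero_loss_with_graph(self)`. Keep in mind that the optimizer step still happens: with momentum (e.g. SGD with momentum or Adam) or weight decay, the weights can change even on a step whose gradients are all zero.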

AsaphLightricks commented 2 years ago

Hi, is there a solution to this issue?

carmocca commented 2 years ago

No, @AsaphLightricks. Returning None from training_step with DDP is currently unsupported.

jhauret commented 2 years ago

I would need to return None from time to time to implement this.

yashpatel5400 commented 2 years ago

Also looking for a patch for this! (My use case is pretty much exactly what @kazimpal87 described above)

magehrig commented 2 years ago

I also need this feature. Is there a known workaround?

EricWiener commented 1 year ago

Hi, any updates on this?

popcornell commented 1 year ago

I second this; this is basically a must-have feature. Even skipping the update entirely on all workers when just one of them returns None would be okay for me. Right now training freezes, so Lightning + DDP cannot be used for a lot of tasks, e.g. automatic speech recognition (ASR), where this issue is common due to the loss used combined with mixed precision.

Couldn't it be done here, basically skipping the reduce altogether and returning None? (I'm not very familiar with Lightning, sorry.) https://github.com/Lightning-AI/lightning/blob/7268670d1aec7d40241962c5bc81b0b871de0a57/src/lightning/pytorch/strategies/ddp.py#LL314C5-L331C1

gianscarpe commented 1 year ago

Hi! I'll give this a look :)

popcornell commented 1 year ago

Great, thanks! BTW, for now this can be worked around with manual optimization, if anyone else has this problem.

On my side, training freezes not on the iteration where None is returned but on the one right after, at this call in pytorch_lightning/loops/training_epoch_loop.py:

    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)

The iteration that returns None apparently goes through without a problem and even updates the progress bar.

popcornell commented 1 year ago

I have a fix, but it lives inside the training loop (it is similar to @tchaton's suggestion above):

    import logging

    import torch
    import torch.distributed as torch_dist

    log = logging.getLogger(__name__)

    def _reduce_loss(self, utt_id, batch_id, loss, reduction="mean"):
        # `loss` holds one unreduced value per utterance in the batch
        assert loss.shape[0] == len(utt_id), "loss must be reduced to batch dimension !"

        # ~isfinite already covers both NaN and +/-inf
        mask_nan_inf = ~torch.isfinite(loss)
        if torch.any(mask_nan_inf):
            where_invalid = torch.where(mask_nan_inf)[0]
            for indx in range(where_invalid.shape[0]):
                inv_indx = where_invalid[indx].item()
                log.info(
                    f"NaN loss in batch {batch_id} of epoch {self.current_epoch}, for utt_id {utt_id[inv_indx]}"
                )
            # if any is invalid then we must flag this to all processes
            flag_skip = torch.ones((), device=loss.device, dtype=torch.bool)
        else:
            flag_skip = torch.zeros((), device=loss.device, dtype=torch.bool)

        # sub-optimal but will do,
        # till they fix it in https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1552650013
        world_size = torch_dist.get_world_size()
        torch_dist.barrier()
        # now gather every rank's flag; every rank must reach this point on every step
        result = [torch.zeros_like(flag_skip) for _ in range(world_size)]
        torch_dist.all_gather(result, flag_skip)
        any_invalid = torch.sum(torch.stack(result)).bool().item()

        if any_invalid:
            # give up after too many consecutive skipped steps
            if self.nan_countdown >= 100:
                raise RuntimeError(
                    "Too many NaNs loss iterations encountered, stopping !"
                )
            self.nan_countdown += 1
            return None  # every rank returns None for this step
        else:
            self.nan_countdown = 1
            return loss.mean() if reduction == "mean" else loss.sum()

Basically, I gather a flag across all DDP workers; if any of the workers sets the flag, all workers must return None. When all of them return None, training no longer freezes. But it would be neat if this were handled inside Lightning; I feel I am adding unnecessary synchronization here.

I am sure that for someone more familiar with Lightning's background magic, it should be easy to add something similar in the right place.
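On the extra synchronization: a possibly simpler flag exchange, assuming a process group that is already initialized (as under Lightning's DDP strategy), is a single `all_reduce` with `ReduceOp.MAX`. It gives every rank the same "did anyone set the flag" answer as the barrier + all_gather + sum in the helper above, and since `all_reduce` is itself a collective that every rank must join, the separate `barrier()` should not be needed. `sync_skip_flag` is a hypothetical name; like the original, it has to be called by every rank on every step.

```python
import torch
import torch.distributed as dist


def sync_skip_flag(local_skip: bool, device: torch.device) -> bool:
    """Return True on every rank if at least one rank passed local_skip=True."""
    flag = torch.tensor(int(local_skip), device=device, dtype=torch.int32)  # 0 = keep, 1 = skip
    if dist.is_available() and dist.is_initialized():
        # every rank contributes its flag; MAX yields 1 everywhere if any rank sent 1
        dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    # without a process group (single device) the local decision is returned as-is
    return bool(flag.item())
```

In `_reduce_loss` this could replace the barrier/all_gather lines with `any_invalid = sync_skip_flag(bool(torch.any(mask_nan_inf)), loss.device)`.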

amorehead commented 1 year ago

@popcornell, thanks for sharing! This logic seems to work for OOM errors occurring in the forward pass. However, when I tried this logic similarly in Lightning's backward hook, I experienced the dreaded DDP freezing issue, suggesting that my DDP ranks have still fallen out of sync. My backward code is as follows. Does anyone here have any ideas about what might be causing the ranks to become out of sync with each other? Notably, my code never logs the warning "Ran out of memory in the backward pass. Skipping batches for all ranks.", which suggests that at least one of the ranks is never hitting the barrier (i.e., torch_dist.barrier()) once all the others have.

    import logging
    from typing import Any

    import torch
    import torch.distributed as torch_dist

    log = logging.getLogger(__name__)

    def backward(self, loss: torch.Tensor, *args: Any, **kwargs: Any):
        """Overrides Lightning's `backward` step to add an out-of-memory (OOM) check."""
        # by default, do not skip the current batch
        skip_flag = torch.zeros(
            (), device=self.device, dtype=torch.bool, requires_grad=False
        )  # NOTE: for skipping batches in a multi-device setting

        try:
            loss.backward(*args, **kwargs, retain_graph=False)
        except RuntimeError as e:
            skip_flag = torch.ones((), device=self.device, dtype=torch.bool, requires_grad=False)
            if "out of memory" in str(e):
                log.warning(
                    f"Ran out of memory in the backward pass, where `torch_dist.is_initialized` is {torch_dist.is_initialized()}. Skipping batch due to: {e}"
                )
                if not torch_dist.is_initialized():
                    # NOTE: for skipping batches in a single-device setting
                    del loss  # drop this reference to the loss and its graph
                    for p in self.net.parameters():
                        if p.grad is not None:
                            del p.grad  # free partially accumulated gradients

        # NOTE: for skipping batches in a multi-device setting
        # credit: https://github.com/Lightning-AI/lightning/issues/5243#issuecomment-1553404417
        if torch_dist.is_initialized():
            # if any rank skips a batch, then all other ranks need to skip
            # their batches as well so DDP can properly keep all ranks synced
            world_size = torch_dist.get_world_size()
            torch_dist.barrier()
            result = [torch.zeros_like(skip_flag) for _ in range(world_size)]
            torch_dist.all_gather(result, skip_flag)
            any_skipped = torch.sum(torch.stack(result)).bool().item()
            if any_skipped:
                del loss  # drop this reference to the loss and its graph
                for p in self.net.parameters():
                    if p.grad is not None:
                        del p.grad
                log.warning(
                    "Ran out of memory in the backward pass. Skipping batches for all ranks."
                )

amorehead commented 1 year ago

Following up on my previous comment, it seems like https://github.com/pytorch/pytorch/issues/18853#issuecomment-698386652 discusses a related issue. In this case, it may be the case that the OOM errors I am seeing are causing loss.backward to be successfully called on some variables of the total autograd graph (but not all of them, since some do not complete their loss.backward call). Said another way, is there a way with DDP to manually tell it to effectively mark certain autograd variables as "having completed their backward pass" even if they actually haven't? That way, DDP ranks would never become out of sync with each other.

awaelchli commented 1 year ago

Your code probably hangs because one process gets the exception and the other does not, and so the one that did not raise will hang at the barrier while watching the other die (sad). When you catch it and run some logic, you will need to synchronize this decision. Make it so that either all of them don't raise, or all of them raise, but a mix is not allowed.

amorehead commented 1 year ago

@awaelchli, thanks for your suggestion. I agree: if a DDP rank raises an exception (e.g., e in my code above), this kind of issue will definitely occur. However, what confuses me is my warning log log.warning(f"Ran out of memory in the backward pass. Skipping batch due to: {e}") gets printed to my respective log file right before my entire training job freezes. This implies that that specific DDP rank that OOM'd indeed went down the if branch of my exception handling code, which should not raise the underlying exception e. To check this though, I am now running a test of this same scenario where I remove the else: raise e branch of the exception handling logic to see if indeed some DDP ranks were raising the actual exception when they shouldn't have been. I'll report back with my findings.

amorehead commented 1 year ago

@awaelchli, sadly, removing the else: raise e branch in my exception handling logic above does not resolve the DDP rank freezing issue I am facing. What else might be causing one of the ranks to die before it reaches the torch_dist.barrier() call? I can confirm that I can see exactly one warning log being issued before my training job completely freezes across all ranks:

Ran out of memory in the backward pass, where `torch_dist.is_initialized` is True. Skipping batch due to: CUDA out of memory. Tried to allocate 5.18 GiB (GPU 0; 79.15 GiB total capacity; 72.98 GiB already allocated; 139.25 MiB free; 78.39 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.

Since torch_dist.is_initialized is confirmed to be set to True here, any DDP rank that logs such a message should not return None, thereby preventing such a rank from "dying early".

amorehead commented 1 year ago

Sadly, it looks like I may have found a culprit for this issue: wandb (using the latest version 0.15.12 - reference: https://github.com/pytorch/xla/issues/1562#issuecomment-579765003).

After disabling my WandbLogger temporarily and rerunning my model training script from the latest checkpoint (shortly after which my script would normally OOM during the loss.backward() call), I am now not seeing the original OOM error which would normally freeze all my DDP ranks from there on. I don't believe I am having wandb watch my model's computational graph in any way, unless similar behavior is enabled by default with Lightning's WandbLogger.

Has anyone else experienced this issue and found a way around it (besides perhaps setting offline=True for WandbLogger and manually syncing local logs to the remote later on)? @sydholl, has the wandb team seen any issues like this before (https://github.com/wandb/wandb/issues/2091)?