Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Lightning creates two DeepSpeedEngine instances for the same model #17523

Closed HeyangQin closed 1 year ago

HeyangQin commented 1 year ago

Bug description

Hello Lightning team!

We got several user reports (e.g. https://github.com/microsoft/DeepSpeed/issues/3068) about errors when using Lightning with DeepSpeed. The issue is that Lightning creates two DeepSpeedEngine instances for the same model at https://github.com/Lightning-AI/lightning/blob/6ec9a6bd9e792f505ebc931742d4235f311eb289/src/lightning/pytorch/strategies/deepspeed.py#L447-L450, yet neither of the DeepSpeedEngines is aware of the existence of the other. So when it comes to ZeRO stage 3 optimization, these two DeepSpeedEngines each manage the same set of parameters independently, which leads to the crash.

We tried to tackle this issue from our end by binding the parameter management to the model so it can be shared among DeepSpeedEngine instances, yet we realized that Lightning creates different wrapper instances for the model before passing it to DeepSpeed, so from the DeepSpeed end these look like different models. DeepSpeed can do both training and validation on the same DeepSpeedEngine instance, so we want to reach out to understand the intuition behind using multiple DeepSpeedEngines (or wrappers), and also to check if there is anything we can do on our end to make a single DeepSpeedEngine usable for both training and validation in your use case.

What version are you seeing the problem on?

master

How to reproduce the bug

There is a pretty nice reproduction script from the user https://github.com/microsoft/DeepSpeed/issues/3068#issuecomment-1486539136
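
For context, a minimal sketch of the calling pattern the report describes (an illustrative stand-in built on the BoringModel demo and the lightning 2.x imports, not the user's actual script):

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy="deepspeed_stage_3",
    precision="16-mixed",
    max_epochs=1,
)
trainer.validate(model)  # first entry point: the model is wrapped and sharded by a DeepSpeedEngine
trainer.fit(model)       # second entry point: another DeepSpeedEngine is created for the same model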

Error messages and logs

No response

Environment

No response

More info

No response

cc @awaelchli

awaelchli commented 1 year ago

Hello @HeyangQin, thanks for providing the reproduction script. I was able to reproduce the error reported in the linked issue. However, if I look at the model at the lines you pointed out above, none of the submodules is wrapped in a DeepSpeed engine, so calling validate() and then fit() always gets the pure LightningModule as input.

This holds at line 447:

assert not [name for name, mod in self.model.named_modules() if isinstance(mod, deepspeed.DeepSpeedEngine)]

So I think the problem is not that there are multiple DeepSpeed engines, but rather that the first call to trainer shards the model, then the user passes the (unwrapped) LightningModule again into the next Trainer.fit() call. When doing so, something seems to go wrong.

In pseudo steps, I think this is what happens:

  1. model = MyLightningModule()
  2. trainer.validate(model) wraps model in engine = DeepSpeedEngine(model)
  3. trainer runs with engine (zero 3), parameters are sharded
  4. user takes original (unwrapped) model (LightningModule) whose parameters are sharded already and passes it again into trainer.fit()
  5. trainer.fit(model) again initializes deepspeed engine = DeepSpeedEngine(model), but this time the original model already has its parameters sharded (I think? A quick check is sketched below the list.)
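
To check steps 4/5 concretely, a small sketch assuming the model variable from the steps above (purely a debugging aid, not part of any fix):

# After the first trainer call under ZeRO stage 3, the parameters left on the raw
# LightningModule are empty placeholders; the full data lives in the shards, and the
# original shape is only recorded in the ds_shape attribute that ZeRO-3 attaches.
for name, p in model.named_parameters():
    print(name, tuple(p.shape), getattr(p, "ds_shape", None))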

Can I ask you, conceptually, what does this error mean?

assert param.ds_status == ZeroParamStatus.AVAILABLE, param.ds_summary()
AssertionError: {'id': 0, 'status': 'INFLIGHT', 'numel': 23440896, 'ds_numel': 23440896, 'shape': (30522, 768), 'ds_shape': (30522, 768), 'requires_grad': True, 'grad_shape': None, 'persist': False, 'active_sub_modules': {233}}

What does inflight mean here?

HeyangQin commented 1 year ago

Hello @awaelchli. Thank you for the investigation and detailed explanation! Yes. I think the pseudo steps you outlined are precisely how the error is triggered.

To give more context about the error message: in ZeRO stage 3 optimization, all the model weights are partitioned (sharded) across all GPUs. When a certain weight is needed for computation, the DeepSpeedEngine fetches it from all GPUs. Inflight means this parameter is currently being fetched (fetching has started but not necessarily completed). Due to the prefetch mechanism, there can be multiple parameters in the inflight status at the same time.
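
As an illustration, a rough sketch of how these per-parameter status flags could be inspected for debugging, assuming engine is the DeepSpeedEngine returned by deepspeed.initialize() under stage 3:

from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

# Every parameter converted by ZeRO-3 carries a ds_status flag:
# NOT_AVAILABLE (partitioned away), INFLIGHT (fetch in progress), AVAILABLE (gathered).
for name, p in engine.module.named_parameters():
    if getattr(p, "ds_status", None) is ZeroParamStatus.INFLIGHT:
        print(f"{name} is currently being fetched:", p.ds_summary())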

At runtime, trainer.validate and trainer.fit each create their own DeepSpeedEngine. Thus when DeepSpeedEngine A puts some parameters inflight, DeepSpeedEngine B gets confused about why there are inflight parameters that it never fetched, and errors out.

We cannot simply have the parameter status shared by all DeepSpeedEngines, because that would stop DeepSpeed from working when there are multiple models at the same time (e.g. RLHF). A straightforward solution from the DeepSpeed side would be to bind the parameter management to the model so it can be shared among DeepSpeedEngine instances for the same model, yet Lightning creates a new wrapper instance for the model before passing it to DeepSpeed, so from the DeepSpeed end it looks like different models. I wonder if it is possible to

  1. Either implement a check on the Lightning side to see if a DeepSpeedEngine has already been initialized for the model, and reuse the existing DeepSpeedEngine to prevent re-instantiation.
  2. Or implement a reuse mechanism for the Lightning wrapper instance so DeepSpeed can correctly recognize that this is the same model. We will update the DeepSpeed framework to tie the parameter status to the model.

I think either way would fix this issue.
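
For what it's worth, a rough sketch of what option 1 could look like (hypothetical helper and attribute names, not Lightning's actual strategy code):

import deepspeed

def initialize_engine_once(module, optimizer, ds_config):
    # Reuse the engine that was already created for this module in an earlier
    # Trainer entry point instead of calling deepspeed.initialize() a second time.
    engine = getattr(module, "_cached_deepspeed_engine", None)
    if engine is None:
        engine, optimizer, _, _ = deepspeed.initialize(
            model=module, optimizer=optimizer, config=ds_config
        )
        module._cached_deepspeed_engine = engine
    return engine, optimizer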

awaelchli commented 1 year ago

@HeyangQin I think I understand now, thanks for the explanation. So if I understand correctly, if this special wrapper https://github.com/Lightning-AI/lightning/blob/f6af74bf158a064673b5db284490b0da1f6c6852/src/lightning/pytorch/strategies/deepspeed.py#L445

wouldn't exist and the raw LightningModule were passed to deepspeed.initialize(), then deepspeed could detect that this is the same model it already used before, and it would work properly, correct? And you are saying that this detection is not yet implemented.

I might know a way to avoid the wrapper in Lightning, but I have to verify that it is possible. Given that this could be a solution, then with the changes both in Lightning and deepspeed, we could resolve this limitation. So I see solution 2 as the most flexible, if possible.

HeyangQin commented 1 year ago

Hello @awaelchli. Yes. We have implemented a patch (not merged into master yet) so that if multiple DeepSpeedEngine instances work on the same model, they will share the parameter management. For the model itself, it doesn't matter whether it is a wrapper or the raw model: as long as deepspeed.initialize() gets the same model, the patch will work.
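
Purely to illustrate the idea (this is not the actual DeepSpeed patch, which lives in the PRs linked below), tying the bookkeeping to the model could look roughly like this:

import torch

# Illustrative sketch only: keep the inflight-parameter bookkeeping on the nn.Module
# itself, so every DeepSpeedEngine built from the same module shares one registry
# instead of each engine tracking its fetches independently.
def get_shared_inflight_registry(module):
    if not hasattr(module, "ds_inflight_param_registry"):
        module.ds_inflight_param_registry = {}  # parameter -> pending fetch handle
    return module.ds_inflight_param_registry

module = torch.nn.Linear(4, 4)
registry = get_shared_inflight_registry(module)
assert registry is get_shared_inflight_registry(module)  # same dict for the same module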

cc my colleague @tjruwase

awaelchli commented 1 year ago

Sounds good. I'll work on a draft so I can verify your patch together with the modifications in Lightning.

awaelchli commented 1 year ago

@HeyangQin Is this the branch for the patch you mentioned? https://github.com/microsoft/DeepSpeed/pull/3380

awaelchli commented 1 year ago

@HeyangQin In my branch at https://github.com/Lightning-AI/lightning/pull/17531 I'm still seeing the inflight-parameter error in test cases where, for example, Trainer.test() and Trainer.fit() are called in sequence.

    self.epoch_loop.run(self._data_fetcher)
../../src/pytorch_lightning/loops/training_epoch_loop.py:137: in run
    self.on_advance_end()
../../src/pytorch_lightning/loops/training_epoch_loop.py:252: in on_advance_end
    self.val_loop.run()
../../src/pytorch_lightning/loops/utilities.py:177: in _decorator
    return loop_run(self, *args, **kwargs)
../../src/pytorch_lightning/loops/evaluation_loop.py:147: in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
../../src/pytorch_lightning/loops/evaluation_loop.py:401: in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
../../src/pytorch_lightning/trainer/call.py:287: in _call_strategy_hook
    output = fn(*args, **kwargs)
../../src/pytorch_lightning/strategies/strategy.py:380: in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
../../src/pytorch_lightning/strategies/strategy.py:563: in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501: in _call_impl
    return forward_call(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/engine.py:1724: in forward
    loss = self.module(*inputs, **kwargs)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1538: in _call_impl
    result = forward_call(*args, **kwargs)
../../src/pytorch_lightning/strategies/strategy.py:556: in wrapped_forward
    out = method(*_args, **_kwargs)
../../src/pytorch_lightning/demos/boring_classes.py:131: in validation_step
    return {"x": self.step(batch)}
../../src/pytorch_lightning/demos/boring_classes.py:124: in step
    output = self(batch)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1547: in _call_impl
    hook_result = hook(self, args, result)
/usr/local/lib/python3.10/dist-packages/deepspeed/utils/nvtx.py:15: in wrapped_fn
    ret_val = func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/parameter_offload.py:329: in _end_of_forward_hook
    self.get_param_coordinator(training=False).reset_step()
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <deepspeed.runtime.zero.partitioned_param_coordinator.PartitionedParameterCoordinator object at 0x7feb792bc5b0>

    def reset_step(self) -> None:
        """indicate that we have completed one fwd+bwd for the model"""
        if self.__inflight_param_registry:
>           raise RuntimeError(f"still have inflight params "
                               f"{[p.ds_summary for p in self.__inflight_param_registry.keys()]}")
E           RuntimeError: still have inflight params [<bound method Init._convert_to_deepspeed_param.<locals>.ds_summary of Parameter containing:
E           tensor([], dtype=torch.float16, requires_grad=True)>, <bound method Init._convert_to_deepspeed_param.<locals>.ds_summary of Parameter containing:
E           tensor([], dtype=torch.float16, requires_grad=True)>]

/usr/local/lib/python3.10/dist-packages/deepspeed/runtime/zero/partitioned_param_coordinator.py:183: RuntimeError
=========================== short test summary info ============================
FAILED strategies/test_deepspeed_strategy.py::test_deepspeed_multigpu_stage_3 - RuntimeError: still have inflight params [<bound method Init._convert_to_de...
======================= 1 failed, 15 warnings in 11.30s ========================

This is with deepspeed installed from source from your branch in https://github.com/microsoft/DeepSpeed/pull/3462. In my branch, I have removed the wrapper around the model completely, so deepspeed.initialize() will see the same model.

When I print the registry

trainer.test(model)
print(model.ds_inflight_param_registry)
trainer.fit(model)

I get an empty dict. Any ideas what could be wrong?

HeyangQin commented 1 year ago

Hi @awaelchli, thanks for the update! My bandwidth is a bit limited right now. Let me look into this and get back to you.

HeyangQin commented 1 year ago

Hello @awaelchli, I tried your branch together with https://github.com/microsoft/DeepSpeed/pull/3462 on the user script, and the error is gone on my end. Could you share the test script you used where you saw this error? Thank you!

awaelchli commented 1 year ago

@HeyangQin I found and fixed the problem with our test case; it was on our end (#17625). Now the stage 3 test cases pass with my branch and your branch together, and I can finish it up. Feel free to go ahead with https://github.com/microsoft/DeepSpeed/pull/3462, great work!

HeyangQin commented 1 year ago

@awaelchli Thank you! It has been a great collaboration effort, and we really appreciate your quick response and great work. We have merged https://github.com/microsoft/DeepSpeed/pull/3462.