Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`strict = False` does not work when the checkpoint is distributed #20274

Open NathanGodey opened 1 month ago

NathanGodey commented 1 month ago

Bug description

When loading a sharded checkpoint with:

fabric.load(ckpt_path, state, strict = False)

the `_distributed_checkpoint_load` function called by the `FSDPStrategy` raises an error if the checkpoint is missing a key that exists in the model held in `state`, which should not happen when `strict = False`.
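
For context, a minimal sketch of the failing call pattern (the FSDP setup, device count, model, and checkpoint path below are placeholders, not the setup from this report):

import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(strategy="fsdp", devices=8)   # placeholder device count
fabric.launch()

model = fabric.setup(nn.Linear(16, 16))       # stand-in for the real model
state = {"model": model}

# Expected with strict=False: keys missing from the sharded checkpoint are skipped.
# Observed: torch.distributed.checkpoint raises a CheckpointException instead.
fabric.load("path/to/sharded_checkpoint", state, strict=False)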

A fix could be to pass a `DefaultLoadPlanner` to `torch.distributed.checkpoint.load`, with its `allow_partial_load` argument set to the opposite of `strict`.
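
For illustration, a minimal sketch of what that could look like inside `_distributed_checkpoint_load` in lightning/fabric/strategies/fsdp.py (the added `strict` parameter is an assumption here, not part of the current signature):

from torch.distributed.checkpoint import DefaultLoadPlanner, load

def _distributed_checkpoint_load(module_state, path, strict=True):
    # allow_partial_load=True lets the load planner skip keys that exist in the
    # module's state_dict but are missing from the sharded checkpoint, which
    # matches the intended semantics of strict=False.
    planner = DefaultLoadPlanner(allow_partial_load=not strict)
    load(module_state, checkpoint_id=path, planner=planner)

`FSDPStrategy.load_checkpoint` would then need to forward its own `strict` flag into this call.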

What version are you seeing the problem on?

v2.4

How to reproduce the bug

No response

Error messages and logs

[rank7]: Traceback (most recent call last):
[rank7]:   File "my_codebase/train_fabric.py", line 226, in <module>
[rank7]:     main(**vars(args))
[rank7]:   File "my_codebase/train_fabric.py", line 148, in main
[rank7]:     fabric.load(ckpt_path, state, strict = strict_mode)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/fabric.py", line 773, in load
[rank7]:     remainder = self._strategy.load_checkpoint(path=path, state=unwrapped_state, strict=strict)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 570, in load_checkpoint
[rank7]:     _distributed_checkpoint_load(module_state, path)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 886, in _distributed_checkpoint_load
[rank7]:     load(module_state, checkpoint_id=path)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 434, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 168, in load
[rank7]:     _load_state_dict(
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 220, in _load_state_dict
[rank7]:     central_plan: LoadPlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 192, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 1)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.lm_model.lm_head.weight.
[rank7]: Traceback (most recent call last): (RANK 2)
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 165, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 66, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_loader.py", line 209, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 197, in create_local_plan
[rank7]:     return create_default_local_load_plan(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "my_codebase/my-venv/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 316, in create_default_local_load_plan
[rank7]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank7]: RuntimeError: Missing key in checkpoint state_dict: model.my_key.

Environment

Current environment

- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0+rocm6.0
- Python version: 3.11

More info

No response

nocoding03 commented 2 weeks ago

Maybe you can try the following:

1. Adjust the `strict` parameter: modify the code so that it sets `allow_partial_load` to `not strict` when calling `torch.distributed.checkpoint.load`, ensuring that partial loading is allowed when `strict=False`.

2. Modify the loading function: you can try to change the loading logic in Fabric directly so that it passes `allow_partial_load`, for example:

from torch.distributed.checkpoint import DefaultLoadPlanner, load

def load_checkpoint_with_partial_support(path, module_state, strict=False):
    # Use the DefaultLoadPlanner with allow_partial_load set to the opposite of strict
    load_planner = DefaultLoadPlanner(allow_partial_load=not strict)
    load(module_state, checkpoint_id=path, planner=load_planner)
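
For reference, a hypothetical usage of the helper above (the checkpoint path is a placeholder, and the "model" key mirrors the state_dict prefix shown in the error log):

import torch.nn as nn

model = nn.Linear(16, 16)                      # stand-in for the real module
module_state = {"model": model.state_dict()}
# With strict=False, keys present in module_state but absent from the checkpoint
# (e.g. model.lm_model.lm_head.weight in the log above) keep their current values
# instead of triggering a CheckpointException.
load_checkpoint_with_partial_support("path/to/sharded_checkpoint", module_state, strict=False)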