NathanGodey opened 1 month ago
Maybe you can try as follows:

1. Adjust the `strict` parameter: modify the code so that it sets `allow_partial_load` to `not strict` when calling `torch.distributed.checkpoint.load`, ensuring that partial loading is allowed when `strict=False`.
2. Modify the loading function: you can try to directly modify the loading logic in Fabric to pass `allow_partial_load`, for example:

```python
from torch.distributed.checkpoint import DefaultLoadPlanner, load


def load_checkpoint_with_partial_support(path, state, strict=False):
    # allow_partial_load=True lets keys missing from the checkpoint be skipped
    load_planner = DefaultLoadPlanner(allow_partial_load=not strict)
    load(state, checkpoint_id=path, planner=load_planner)
```
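For illustration, a hypothetical call site for this helper; the model, checkpoint path, and layout of `state` below are placeholders, not the actual Fabric internals:

```python
import torch

# Build a state dict to be filled in-place from the checkpoint directory;
# with strict=False, keys absent from the checkpoint are simply skipped.
model = torch.nn.Linear(4, 4)
state = {"model": model.state_dict()}
load_checkpoint_with_partial_support("checkpoints/last.ckpt", state, strict=False)
```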
Bug description
When loading a sharded checkpoint with `strict=False`, the `_distributed_checkpoint_load` function called in the `FSDPStrategy` will raise an error if the checkpoint is missing a key from the model in `state`, which should not happen since `strict=False`. A fix could be to take advantage of the `DefaultLoadPlanner` in `torch.distributed.checkpoint.load`, setting its `allow_partial_load` argument to the opposite of `strict`.
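A minimal sketch of that fix, assuming `_distributed_checkpoint_load` receives (or can be given) the `strict` flag; the exact signature inside `FSDPStrategy` may differ:

```python
from pathlib import Path
from typing import Any, Dict

from torch.distributed.checkpoint import DefaultLoadPlanner, load


def _distributed_checkpoint_load(module_state: Dict[str, Any], path: Path, strict: bool = True) -> None:
    # allow_partial_load=True makes the planner skip keys that exist in
    # module_state but are absent from the checkpoint instead of raising,
    # matching the expected semantics of strict=False.
    planner = DefaultLoadPlanner(allow_partial_load=not strict)
    load(module_state, checkpoint_id=path, planner=planner)
```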
What version are you seeing the problem on?
v2.4
How to reproduce the bug
No response
Error messages and logs
Environment
Current environment
```
#- PyTorch Lightning Version (e.g., 2.4.0): 2.4.0
#- PyTorch Version (e.g., 2.4): 2.4.0+rocm6.0
#- Python version (e.g., 3.12): 3.11
```
More info
No response