Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

FSDP checkpointing uses deprecated APIs with PyTorch 2.2 #19462

Open · carmocca opened this issue 8 months ago

carmocca commented 8 months ago

Bug description

See added deprecation warnings in https://github.com/pytorch/pytorch/pull/113867

What version are you seeing the problem on?

v2.2

How to reproduce the bug

The warning originates from:

https://github.com/Lightning-AI/pytorch-lightning/blob/b097a4df3f3fa8b4465861ccab17a44a8ae1ebb9/src/lightning/fabric/strategies/fsdp.py#L496

We already use the newer API for loading:

https://github.com/Lightning-AI/pytorch-lightning/blob/b097a4df3f3fa8b4465861ccab17a44a8ae1ebb9/src/lightning/fabric/strategies/fsdp.py#L563-L566
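
For reference, a minimal sketch of the migration in save_checkpoint. The state-dict contents and the checkpoint directory here are made up; converted_state and writer stand in for the variables at the linked lines:

    # Sketch only: swap the deprecated call for its replacement (PyTorch >= 2.2).
    import torch
    import torch.distributed.checkpoint as dcp

    converted_state = {"weight": torch.zeros(4)}     # stand-in for the FSDP state dict
    writer = dcp.FileSystemWriter("checkpoint-dir")  # hypothetical target directory

    # Deprecated since 2.2, emits the UserWarning quoted below:
    #   dcp.save_state_dict(converted_state, writer)

    # Replacement; passing storage_writer by keyword also avoids the
    # argument-order change discussed further down in this thread:
    dcp.save(converted_state, storage_writer=writer)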

Error messages and logs

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py:31: UserWarning: 'save_state_dict' is deprecated and will be removed in future versions.Please use 'save' instead.
  warnings.warn(

Environment

No response

More info

No response

cc @awaelchli @carmocca

carmocca commented 8 months ago

Two more warnings, which probably need to be fixed in PyTorch itself:

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py:1132: UserWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  warnings.warn(DEPRECATE_MSG)

From (print_stack added by me):

  File "/home/carlos/stuff.py", line 29, in <module>
    fabric.save(f"{compile}_before_fwd", {"model": fmodel})
  File "/home/carlos/lightning/src/lightning/fabric/fabric.py", line 770, in save
    self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state), filter=filter)
  File "/home/carlos/lightning/src/lightning/fabric/strategies/fsdp.py", line 484, in save_checkpoint
    converted = obj.state_dict()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1922, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 737, in _post_state_dict_hook
    local_shape = tensor.shape
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/_shard/sharded_tensor/api.py", line 1134, in __torch_function__
    traceback.print_stack()

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py:151: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():

From (print_stack added by me):

  File "/home/carlos/stuff.py", line 29, in <module>
    fabric.save(f"{compile}_before_fwd", {"model": fmodel})
  File "/home/carlos/lightning/src/lightning/fabric/fabric.py", line 770, in save
    self._strategy.save_checkpoint(path=path, state=_unwrap_objects(state), filter=filter)
  File "/home/carlos/lightning/src/lightning/fabric/strategies/fsdp.py", line 496, in save_checkpoint
    save_state_dict(converted_state, writer)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 40, in save_state_dict
    return _save_state_dict(
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 280, in _save_state_dict
    return distW.all_reduce("write", write_data, finish_checkpoint)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py", line 210, in all_reduce
    local_data = map_fun()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 270, in write_data
    all_writes = storage_writer.write_data(final_local_plan, planner)
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 470, in write_data
    _write_files_from_queue(
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 284, in _write_files_from_queue
    loader.start_loading()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 179, in start_loading
    self._refill()
  File "/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/filesystem.py", line 150, in _refill
    traceback.print_stack()
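
The TypedStorage warning comes from an upstream call site, and the suggested replacement is not a pure rename. A small illustration of the unit change (my own, not PyTorch's actual fix):

    # TypedStorage.size() counts elements of the tensor's dtype, while
    # UntypedStorage.size() counts bytes, so the comparison must change units.
    import torch

    t = torch.zeros(4, dtype=torch.float32)
    # Deprecated form: t.storage().size() returns 4 (elements)
    # Replacement: 16 bytes, compared against numel() * element_size()
    assert t.untyped_storage().size() == t.numel() * t.element_size()
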
carmocca commented 8 months ago

If the newer save is used, the argument order changed in https://github.com/pytorch/pytorch/pull/117772:

/home/carlos/nightly-env/lib/python3.10/site-packages/torch/distributed/checkpoint/utils.py:409: UserWarning: The argument order of save has been changed. Please check the document to avoid future breakages.
  warnings.warn(

This probably applies to load too; I haven't tried it.
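
A minimal sketch of the keyword-only style that avoids depending on the positional order (the checkpoint directory is made up):

    # Sketch only: keyword arguments work with both the old and the new order.
    import torch
    import torch.distributed.checkpoint as dcp

    state = {"weight": torch.zeros(4)}

    dcp.save(state, storage_writer=dcp.FileSystemWriter("ckpt-dir"))
    dcp.load(state, storage_reader=dcp.FileSystemReader("ckpt-dir"))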

awaelchli commented 8 months ago

I agree we need to update these imports. The change in argument order is only in nightly, but since lit-gpt relies on nightly, we should start incorporating this ASAP.

carmocca commented 8 months ago

Technically, lit-gpt hasn't relied on nightly since the 2.2 release.

I opened #19463

carmocca commented 8 months ago

Also opened https://github.com/pytorch/pytorch/issues/119802 upstream. We might want to silence these warnings once that is resolved.

carmocca commented 8 months ago

https://github.com/pytorch/pytorch/issues/119800#issuecomment-1942156271 suggests that, on 2.2+, we should replace most of what we have with the {get,set}_{model,optimizer}_state_dict functions in https://github.com/pytorch/pytorch/blob/v2.2.0/torch/distributed/checkpoint/state_dict.py.
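
A rough sketch of what that migration could look like; the plain Linear and SGD below are stand-ins for the FSDP-wrapped objects Fabric actually manages:

    # Sketch only, against the torch.distributed.checkpoint.state_dict module in 2.2.
    import torch
    from torch.distributed.checkpoint.state_dict import (
        StateDictOptions,
        get_model_state_dict,
        get_optimizer_state_dict,
        set_model_state_dict,
        set_optimizer_state_dict,
    )

    model = torch.nn.Linear(2, 2)  # stand-in for an FSDP-wrapped module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    options = StateDictOptions(full_state_dict=False, cpu_offload=False)

    # Saving: these helpers hide the FSDP state-dict-type juggling we do today.
    model_sd = get_model_state_dict(model, options=options)
    optim_sd = get_optimizer_state_dict(model, optimizer, options=options)

    # Loading:
    set_model_state_dict(model, model_sd, options=options)
    set_optimizer_state_dict(model, optimizer, optim_state_dict=optim_sd, options=options)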