pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

Cannot run FSDP2 with low bit optim from AO #1189

Open nighting0le01 opened 1 day ago

nighting0le01 commented 1 day ago

Cannot run FSDP2 with the low-bit optimizers from torchao. Saving a checkpoint fails with:

[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/multi_parallel/oct_28/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1365, in save_checkpoint
[rank7]:     self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/model_parallel.py", line 321, in save_checkpoint
[rank7]:     _distributed_checkpoint_save(converted_state, path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/fabric/strategies/fsdp.py", line 867, in _distributed_checkpoint_save
[rank7]:     save(converted_state, checkpoint_id=path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 429, in inner_func
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 152, in save
[rank7]:     return _save_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 316, in _save_state_dict
[rank7]:     central_plan: SavePlan = distW.reduce_scatter("plan", local_step, global_step)
[rank7]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 191, in reduce_scatter
[rank7]:     raise result
[rank7]: torch.distributed.checkpoint.api.CheckpointException: CheckpointException ranks:dict_keys([0, 1, 2, 3, 4, 5, 6, 7])
[rank7]: Traceback (most recent call last): (RANK 0)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/utils.py", line 164, in reduce_scatter
[rank7]:     local_data = map_fun()
[rank7]:                  ^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/logger.py", line 83, in wrapper
[rank7]:     result = func(*args, **kwargs)
[rank7]:              ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict_saver.py", line 303, in local_step
[rank7]:     local_plan = planner.create_local_plan()
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 101, in create_local_plan
[rank7]:     plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/default_planner.py", line 399, in create_default_local_save_plan
[rank7]:     requests += _create_write_items(fqn, obj)
[rank7]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 222, in _create_write_items
[rank7]:     return object.__create_write_items__(fqn, object)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 598, in __create_write_items__
[rank7]:     return [_create_write_items_for_dtensor(fqn, object)]
[rank7]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/planner_helpers.py", line 86, in _create_write_items_for_dtensor
[rank7]:     properties=TensorProperties.create_from_tensor(tensor.to_local()),
[rank7]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/checkpoint/metadata.py", line 108, in create_from_tensor
[rank7]:     pin_memory=tensor.is_pinned(),
[rank7]:                ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 377, in _dispatch__torch_function__
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torchao/utils.py", line 393, in _dispatch__torch_dispatch__
[rank7]:     raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")
[rank7]: NotImplementedError: OptimState8bit dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.is_pinned', overload='default')>, types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), arg_types=(<class 'torchao.prototype.low_bit_optim.subclass_8bit.OptimState8bit'>,), kwarg_types={}
nighting0le01 commented 1 day ago

@gau-nernst can you please take a look? aten.is_pinned is not implemented for OptimState8bit, so saving optimizer states fails.

gau-nernst commented 1 day ago

Do you have a small reproduction? Yeah, I think we don't currently test saving/loading optimizers with FSDP2. I will add tests for it.

nighting0le01 commented 12 hours ago

Unfortunately, I cannot share the original code, but being able to save optimizer states properly with FSDP2 and other DTensor-based parallelism is crucial for making use of these low-bit optimizers. @gau-nernst

gau-nernst commented 8 hours ago

It looks like you are using PyTorch Lightning, which calls torch.distributed.checkpoint.state_dict_saver.save(). This function requires the tensor subclass to implement aten.is_pinned.
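For illustration, here is a minimal sketch of how a wrapper tensor subclass can answer aten.is_pinned inside __torch_dispatch__ by delegating to its inner tensors. `WrappedState`, `codes` and `scale` are hypothetical names loosely modeled on a quantized optimizer state; this is not torchao's actual OptimState8bit implementation.

```python
import torch

aten = torch.ops.aten


class WrappedState(torch.Tensor):
    """Hypothetical wrapper subclass holding quantized codes plus a scale."""

    @staticmethod
    def __new__(cls, codes: torch.Tensor, scale: torch.Tensor):
        # The wrapper only carries metadata (shape, dtype, device);
        # the actual data lives in the inner tensors.
        return torch.Tensor._make_wrapper_subclass(
            cls, codes.shape, dtype=torch.float32, device=codes.device
        )

    def __init__(self, codes: torch.Tensor, scale: torch.Tensor):
        self.codes = codes
        self.scale = scale

    # Send every op straight to __torch_dispatch__.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if func is aten.is_pinned.default:
            # Report pinned only if every inner tensor is pinned.
            state = args[0]
            return state.codes.is_pinned() and state.scale.is_pinned()
        raise NotImplementedError(f"{cls.__name__} dispatch: unimplemented {func}")
```

Reporting the wrapper as pinned only when all inner tensors are pinned is one conservative choice; the actual fix in torchao may differ.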

I tested that the plain torch.save(optim.state_dict(), "state_dict.ckpt") works fine. @nighting0le01 In the meantime, is it possible for you to switch to plain torch.save() for checkpointing?
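A minimal sketch of the per-rank torch.save() workaround, assuming each rank saves and later reloads its own shard of the optimizer state with the same world size and sharding (this path does no resharding). The file naming scheme and the helper names are hypothetical, and plain torch.optim.AdamW in a single CPU process stands in for AdamW8bit under FSDP2.

```python
import os
import tempfile

import torch


def optim_shard_path(ckpt_dir: str, rank: int) -> str:
    # Hypothetical naming scheme: one file per rank.
    return os.path.join(ckpt_dir, f"optim_rank{rank}.pt")


def save_optim_shard(optim: torch.optim.Optimizer, ckpt_dir: str, rank: int) -> None:
    # Under FSDP2, this rank's state dict holds DTensors wrapping its own shards.
    torch.save(optim.state_dict(), optim_shard_path(ckpt_dir, rank))


def load_optim_shard(optim: torch.optim.Optimizer, ckpt_dir: str, rank: int) -> None:
    # weights_only=False because tensor subclasses may not be on the
    # weights-only allowlist; only use this for trusted checkpoint files.
    state_dict = torch.load(optim_shard_path(ckpt_dir, rank), weights_only=False)
    optim.load_state_dict(state_dict)


# Single-process demo; under FSDP2, rank would be torch.distributed.get_rank().
model = torch.nn.Linear(4, 2)
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(8, 4)).sum().backward()
optim.step()

with tempfile.TemporaryDirectory() as ckpt_dir:
    save_optim_shard(optim, ckpt_dir, rank=0)
    new_optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
    load_optim_shard(new_optim, ckpt_dir, rank=0)

restored = torch.equal(
    new_optim.state[model.weight]["exp_avg"], optim.state[model.weight]["exp_avg"]
)
```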

@awgu What are the benefits of using torch.distributed.checkpoint.state_dict_saver.save() over plain torch.save()? From my understanding of https://pytorch.org/docs/stable/distributed.checkpoint.html, the former seems to handle some kind of resharding at load time? Is saving the same?

Implementing the aten.is_pinned op is simple, but since I'm not too familiar with torch.distributed.checkpoint, what is the recommended way to save and load a (sharded) optimizer state dict with it, so that I can add the correct tests? Is it something like this:

from torch.distributed.checkpoint import state_dict_saver, state_dict_loader
from torchao.prototype.low_bit_optim import AdamW8bit

fsdp_model = ...
fsdp_optim = AdamW8bit(fsdp_model.parameters())

# do some training, so optimizer states are initialized

# save() is a collective: every rank passes the same checkpoint directory
checkpoint_dir = "optim_checkpoint"
state_dict_saver.save(fsdp_optim.state_dict(), checkpoint_id=checkpoint_dir)

# new sharded optim. optimizer states are not initialized
new_fsdp_optim = AdamW8bit(fsdp_model.parameters())
state_dict = new_fsdp_optim.state_dict()

# this requires aten.detach, and it doesn't seem to load the optim state when the new optim state is empty (i.e. not initialized)
state_dict_loader.load(state_dict, checkpoint_id=checkpoint_dir)
new_fsdp_optim.load_state_dict(state_dict)