pytorch / ao

PyTorch native quantization and sparsity for training and inference

Unable to save checkpoints when using low-bit optimizers with FSDP1 or FSDP2 #1185

Open · nighting0le01 opened this issue 2 weeks ago

nighting0le01 commented 2 weeks ago

This only occurs when using 8-bit Adam.

With FSDP1 I run into the traceback below.

FSDP config: param_dtype: bf16, reduce_dtype: fp32 (see the MixedPrecision sketch that follows).
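
For reference, the config above corresponds roughly to this FSDP1 `MixedPrecision` policy (a sketch, not the reporter's actual config object):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# param_dtype: bf16, reduce_dtype: fp32 as described above
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
```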

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/nfs/asahni/parallel_expts/training/scripts/train.py", line 226, in <module>
    main()
[rank7]: Traceback (most recent call last):
[rank7]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank7]:   File "<frozen runpy>", line 88, in _run_code
[rank7]:   File "/nfs/asahni/parallel_expts/training/scripts/train.py", line 226, in <module>
[rank7]:     main()
[rank7]:   File "/nfs/asahni/parallel_expts/training/scripts/train.py", line 218, in main
[rank7]:     trainer.fit(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
[rank7]:     call._call_and_handle_interrupt(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
[rank7]:     return trainer_fn(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
[rank7]:     self._run(model, ckpt_path=ckpt_path)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
[rank7]:     results = self._run_stage()
[rank7]:               ^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
[rank7]:     self.fit_loop.run()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
[rank7]:     self.advance()
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
[rank7]:     self.epoch_loop.run(self._data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
[rank7]:     self.advance(data_fetcher)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 269, in advance
[rank7]:     call._call_callback_hooks(trainer, "on_train_batch_end", batch_output, batch, batch_idx)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 218, in _call_callback_hooks
[rank7]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 316, in on_train_batch_end
[rank7]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank7]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 715, in _save_none_monitor_checkpoint
[rank7]:     self._save_checkpoint(trainer, filepath)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 390, in _save_checkpoint
[rank7]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1364, in save_checkpoint
[rank7]:     checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 447, in dump_checkpoint
[rank7]:     optimizer_state = trainer.strategy.optimizer_state(optimizer)
[rank7]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/lightning/pytorch/strategies/fsdp.py", line 539, in optimizer_state
[rank7]:     state_dict = FSDP.optim_state_dict(self.model, optimizer)
[rank7]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1890, in optim_state_dict
[rank7]:     return FullyShardedDataParallel._optim_state_dict_impl(
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1301, in _optim_state_dict_impl
[rank7]:     return _optim_state_dict(
[rank7]:            ^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank7]:     return func(*args, **kwargs)
[rank7]:            ^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1976, in _optim_state_dict
[rank7]:     fsdp_osd_state = convert_fn(
[rank7]:                      ^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1799, in _convert_state_with_orig_params
[rank7]:     _gather_all_orig_param_state(
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1693, in _gather_all_orig_param_state
[rank7]:     output_states = _allgather_orig_param_states(
[rank7]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1523, in _allgather_orig_param_states
[rank7]:     dtype, state_buffers = _convert_all_state_info(
[rank7]:                            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank7]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1382, in _convert_all_state_info
[rank7]:     assert dtype == info.dtype
[rank7]:            ^^^^^^^^^^^^^^^^^^^
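
For context, a minimal sketch of the kind of setup that reaches this code path (toy model and hypothetical names, not the actual training script; assumes torchao's 8-bit AdamW from `torchao.prototype.low_bit_optim`, whose import path may differ across torchao versions):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision
from torchao.prototype.low_bit_optim import AdamW8bit  # assumed import path, may differ by version

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
model = FSDP(
    model,
    mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.float32),
    use_orig_params=True,  # the traceback goes through the orig-params path
)
optim = AdamW8bit(model.parameters(), lr=1e-4)

# one step so the 8-bit optimizer state (exp_avg / exp_avg_sq) is materialized
model(torch.randn(8, 1024, device="cuda")).sum().backward()
optim.step()
optim.zero_grad()

# Lightning's FSDP strategy calls this when saving a checkpoint; this is where the
# "assert dtype == info.dtype" above fires
osd = FSDP.optim_state_dict(model, optim)
```

Judging from the traceback alone, the assertion seems to expect all gathered optimizer state tensors for a parameter to share one dtype, which may not hold when the state is stored in a quantized 8-bit representation, but that is only a guess from the stack trace.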
gau-nernst commented 2 weeks ago

Is this a duplicate of #1189? Can I close this so we can discuss it over at the other issue?

nighting0le01 commented 2 weeks ago

Hi @gau-nernst, this is not exactly a duplicate; it is for FSDP1. But we can shift the discussion there if you prefer.

gau-nernst commented 2 weeks ago

I don't think we actively support FSDP1. If you can create a minimal reproducible example, I can look into it. The errors seem different from those in #1189.