pytorch / torchtitan

A native PyTorch Library for large model training

[FSDP2 + TP + DCP] Error when checkpointing using TP+FSDP2+DCP #441

Closed thomasneff closed 3 months ago

thomasneff commented 4 months ago

Hi!

Amazing job on the example implementations of all of these cutting-edge training features! However, when trying to run the following configuration, I ran into issues:

Llama 3 8B
data_parallel_degree = 4
tensor_parallel_degree = 2
pipeline_parallel_degree = 1
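
For context, here's a minimal sketch of what this dp=4 x tp=2 layout roughly corresponds to when wired up by hand with PyTorch's composable APIs (illustrative only, not torchtitan's actual configuration path; the toy model, launch assumptions, and names are mine):

# Assumes an 8-rank torchrun launch; dp=4 x tp=2.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh_2d = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).cuda()

# Tensor parallelism over the "tp" mesh dimension.
parallelize_module(
    model,
    mesh_2d["tp"],
    {"0": ColwiseParallel(), "1": RowwiseParallel()},
)

# FSDP2 sharding over the "dp" mesh dimension.
fully_shard(model, mesh=mesh_2d["dp"])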

Once it starts checkpointing, I get the following error:

raise NotImplementedError(
  NotImplementedError: 2D state_dict is under development. Please check https://github.com/pytorch/pytorch/issues/129627 for more details.

Checking https://github.com/pytorch/pytorch/issues/129627, it's a bit unclear what the status of the fix is, and how an FSDP2+TP pipeline should be handled with DCP at this point in time / as of PyTorch 2.4.

The torchtitan README mentions TP/DCP/FSDP/PP all working, so I'd have expected this to work too, especially since there's also a 3D-parallel test in test_runner.py (which presumably also doesn't work?).

If there's a workaround for this (other than just not using DCP), I'd love to know; otherwise, it might be good to add a comment/mention in the README if this is known to be broken at the moment.

My env:

pytorch-triton==3.0.0+dedb7bdf33
torch==2.5.0.dev20240709+cu121
torchdata==0.7.1.dev20240709+cpu

Thanks!

awgu commented 4 months ago

@thomasneff We recently disabled the FSDP2 + TP state dict because we need to implement something called strided sharding in order for the data layout in the state dict to be correct for DCP resharding. While that is being worked on, we wanted to make sure no one saved any state dicts for now, to avoid backward-compatibility issues where a saved state dict could not be loaded again.

We are working on it! See https://github.com/pytorch/pytorch/issues/129627 for some more details.
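
In case it helps, here's a minimal sketch of where the guard fires (assuming a model sharded with both fully_shard and TP, e.g. as in the 2D sketch earlier in this thread; the try/except is only for illustration):

from torch.distributed.checkpoint.state_dict import get_model_state_dict

try:
    # `model` is a 2D-parallel (FSDP2 + TP) model.
    state_dict = get_model_state_dict(model)
except NotImplementedError as e:
    # "2D state_dict is under development. Please check
    # https://github.com/pytorch/pytorch/issues/129627 for more details."
    print(e)

It is a deliberate guard raised when state_dict() is requested, not a failure inside DCP's save path itself.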

thomasneff commented 4 months ago

Awesome, thanks for clarifying - that's what I thought, and it makes sense. We'll work around it differently then, but I'm already waiting in excitement for when this eventually ships 🙂

mayank31398 commented 3 months ago

Seems like even HSDP without model parallelism is disabled now.

awgu commented 3 months ago

@mayank31398 thanks for pointing this out! We should fix this (and add some HSDP state dict save/load unit test).
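
For reference, a rough sketch (assumed mesh shape, toy model, and checkpoint path; not actual test code) of the save/load round trip such a unit test could exercise, given that fully_shard over a 2D mesh does HSDP (replicate across dim 0, shard across dim 1):

import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

# Assumes an 8-rank launch: 2-way replication x 4-way sharding.
hsdp_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("replicate", "shard"))

model = nn.Linear(1024, 1024).cuda()
fully_shard(model, mesh=hsdp_mesh)

# Save the sharded state dict with DCP.
dcp.save(get_model_state_dict(model), checkpoint_id="/tmp/hsdp_ckpt")

# Load it back into a fresh copy of the model.
new_model = nn.Linear(1024, 1024).cuda()
fully_shard(new_model, mesh=hsdp_mesh)
state_dict = get_model_state_dict(new_model)
dcp.load(state_dict, checkpoint_id="/tmp/hsdp_ckpt")
set_model_state_dict(new_model, state_dict)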

loretoparisi commented 2 months ago

@awgu is this the reason behind

[rank4]: Traceback (most recent call last):
[rank4]:   File "train.py", line 121, in <module>
[rank4]:     train()
[rank4]:   File "train.py", line 112, in train
[rank4]:     trainer.fit(model)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank4]:     call._call_and_handle_interrupt(
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank4]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]:     return function(*args, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank4]:     self._run(model, ckpt_path=ckpt_path)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank4]:     results = self._run_stage()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank4]:     self.fit_loop.run()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
[rank4]:     self.on_advance_end()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 378, in on_advance_end
[rank4]:     call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 210, in _call_callback_hooks
[rank4]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 323, in on_train_epoch_end
[rank4]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 385, in _save_topk_checkpoint
[rank4]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 713, in _save_none_monitor_checkpoint
[rank4]:     self._save_checkpoint(trainer, filepath)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 388, in _save_checkpoint
[rank4]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1369, in save_checkpoint
[rank4]:     checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 436, in dump_checkpoint
[rank4]:     "state_dict": self._get_lightning_module_state_dict(),
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 497, in _get_lightning_module_state_dict
[rank4]:     return self.trainer.strategy.lightning_module_state_dict()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/model_parallel.py", line 262, in lightning_module_state_dict
[rank4]:     return get_model_state_dict(self.model, options=state_dict_options)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict.py", line 976, in get_model_state_dict
[rank4]:     model_state_dict = _get_model_state_dict(model, info)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict.py", line 466, in _get_model_state_dict
[rank4]:     state_dict = _state_dict_fn(model, "state_dict")()
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1939, in state_dict
[rank4]:     module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1939, in state_dict
[rank4]:     module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1935, in state_dict
[rank4]:     hook(self, prefix, keep_vars)
[rank4]:   File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 154, in _raise_not_implemented_if_2d
[rank4]:     raise NotImplementedError(
[rank4]: NotImplementedError: 2D state_dict is under development. Please check https://github.com/pytorch/pytorch/issues/129627 for more details.

I'm running the Llama 3 tensor-parallel example provided here: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/pytorch/tensor_parallel

So how can I make it work? Thanks!

awgu commented 2 months ago

@loretoparisi I think the latest nightlies should have this fix now!
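
If you're not sure whether the build you have is new enough, here's a quick illustrative sanity check (the install command in the comment assumes the cu121 nightly index):

# e.g. pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
import torch

print(torch.__version__)  # nightly builds look like "2.5.0.devYYYYMMDD+cu121"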

mayank31398 commented 2 months ago

Yes, I have tested on the latest nightly, not with torchtitan but with a custom repository. It works fine now :)