thomasneff closed this issue 3 months ago.
@thomasneff We recently disabled state dict support for FSDP2 + TP because we need to implement something called strided sharding so that the data layout in the state dict is correct for DCP resharding. While that is being worked on, we want to make sure no one saves any state dicts, to avoid backward-compatibility (BC) issues where a saved state dict could not be loaded again.
We are working on it! See https://github.com/pytorch/pytorch/issues/129627 for some more details.
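To illustrate the layout problem described above, here is a conceptual sketch in plain Python. It makes no PyTorch calls, and all function names are hypothetical. It assumes a weight whose dim 0 is sharded by TP first and then sharded again on dim 0 by FSDP2 over a 2D ("dp", "tp") mesh. It also assumes that placements of (Shard(0), Shard(0)) on that mesh are read in mesh-dimension order (split across dp first, then across tp). Under those assumptions, the rows a rank actually holds differ from the rows the naive placements describe, so a resharding load that trusts the naive offsets scrambles the rows. Strided sharding is meant to record the real layout instead.

```python
# Conceptual sketch only (plain Python, no PyTorch): row ranges stand in for
# tensor shards. All names here are hypothetical, for illustration.

def fsdp_after_tp_rows(num_rows, dp, tp, dp_rank, tp_rank):
    """Rows a rank actually holds: TP shards dim 0 first, then FSDP2 shards
    that TP-local chunk again on dim 0 across the dp ranks."""
    tp_chunk = num_rows // tp
    dp_chunk = tp_chunk // dp
    start = tp_rank * tp_chunk + dp_rank * dp_chunk
    return range(start, start + dp_chunk)


def naive_2d_rows(num_rows, dp, tp, dp_rank, tp_rank):
    """Rows implied by reading (Shard(0), Shard(0)) on a (dp, tp) mesh in
    mesh-dimension order: split across dp first, then across tp."""
    dp_chunk = num_rows // dp
    tp_chunk = dp_chunk // tp
    start = dp_rank * dp_chunk + tp_rank * tp_chunk
    return range(start, start + tp_chunk)


def reassemble(num_rows, dp, tp, layout_fn):
    """Simulate a resharding load: every rank's real rows are written to the
    global offsets that the layout_fn metadata claims for that rank."""
    full = [None] * num_rows
    for dp_rank in range(dp):
        for tp_rank in range(tp):
            actual = fsdp_after_tp_rows(num_rows, dp, tp, dp_rank, tp_rank)
            claimed = layout_fn(num_rows, dp, tp, dp_rank, tp_rank)
            for src, dst in zip(actual, claimed):
                full[dst] = src  # store the original row index
    return full


if __name__ == "__main__":
    rows, dp, tp = 8, 2, 2
    print(reassemble(rows, dp, tp, naive_2d_rows))       # [0, 1, 4, 5, 2, 3, 6, 7]
    print(reassemble(rows, dp, tp, fsdp_after_tp_rows))  # [0, 1, 2, 3, 4, 5, 6, 7]
```

With dp = tp = 2 and 8 rows, ranks (dp=0, tp=1) and (dp=1, tp=0) end up with swapped row blocks under the naive metadata. Only metadata that describes the TP-then-FSDP order round-trips correctly.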
Awesome, thanks for clarifying. That's what I thought, and it makes sense. We'll work around it differently then, but I'm already excitedly waiting for this to eventually ship 🙂
It seems like even HSDP without model parallelism is disabled now.
@mayank31398 thanks for pointing this out! We should fix this (and add HSDP state dict save/load unit tests).
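For context on the HSDP report: FSDP2's `fully_shard` treats a 2D data-parallel mesh as HSDP, replicating parameters across mesh dim 0 and sharding them across mesh dim 1, so HSDP parameters are 2D DTensors even without tensor parallelism. That may be why the 2D state-dict guard also caught HSDP, though this is our assumption. The plain-Python sketch below uses a hypothetical helper and assumes row-major rank order (as with `init_device_mesh`). It shows which ranks split a parameter between them and which hold replicas of the same shard.

```python
# Conceptual sketch (plain Python, no PyTorch): rank grouping for a 2D HSDP
# mesh of shape (replicate, shard), assuming ranks are laid out row-major.

def hsdp_groups(world_size, shard_size):
    """Return (mesh, shard_groups, replicate_groups) for a hypothetical HSDP setup."""
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    replicate_size = world_size // shard_size
    # mesh[r][s] is the global rank at replicate index r, shard index s
    mesh = [[r * shard_size + s for s in range(shard_size)]
            for r in range(replicate_size)]
    # Each row shards parameters among its ranks (FSDP within the row)
    shard_groups = [list(row) for row in mesh]
    # Each column holds identical shards (replicas); gradients are all-reduced across it
    replicate_groups = [[mesh[r][s] for r in range(replicate_size)]
                        for s in range(shard_size)]
    return mesh, shard_groups, replicate_groups


if __name__ == "__main__":
    mesh, shard_groups, replicate_groups = hsdp_groups(world_size=8, shard_size=4)
    print(shard_groups)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(replicate_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```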
@awgu is this the reason behind the following error?
[rank4]: Traceback (most recent call last):
[rank4]: File "train.py", line 121, in <module>
[rank4]: train()
[rank4]: File "train.py", line 112, in train
[rank4]: trainer.fit(model)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
[rank4]: call._call_and_handle_interrupt(
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank4]: return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank4]: return function(*args, **kwargs)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
[rank4]: self._run(model, ckpt_path=ckpt_path)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
[rank4]: results = self._run_stage()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1030, in _run_stage
[rank4]: self.fit_loop.run()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 206, in run
[rank4]: self.on_advance_end()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/loops/fit_loop.py", line 378, in on_advance_end
[rank4]: call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/call.py", line 210, in _call_callback_hooks
[rank4]: fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 323, in on_train_epoch_end
[rank4]: self._save_topk_checkpoint(trainer, monitor_candidates)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 385, in _save_topk_checkpoint
[rank4]: self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 713, in _save_none_monitor_checkpoint
[rank4]: self._save_checkpoint(trainer, filepath)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 388, in _save_checkpoint
[rank4]: trainer.save_checkpoint(filepath, self.save_weights_only)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/trainer.py", line 1369, in save_checkpoint
[rank4]: checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 436, in dump_checkpoint
[rank4]: "state_dict": self._get_lightning_module_state_dict(),
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 497, in _get_lightning_module_state_dict
[rank4]: return self.trainer.strategy.lightning_module_state_dict()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/lightning/pytorch/strategies/model_parallel.py", line 262, in lightning_module_state_dict
[rank4]: return get_model_state_dict(self.model, options=state_dict_options)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict.py", line 976, in get_model_state_dict
[rank4]: model_state_dict = _get_model_state_dict(model, info)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict.py", line 466, in _get_model_state_dict
[rank4]: state_dict = _state_dict_fn(model, "state_dict")()
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1939, in state_dict
[rank4]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1939, in state_dict
[rank4]: module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1935, in state_dict
[rank4]: hook(self, prefix, keep_vars)
[rank4]: File "/home/coder/.local/lib/python3.8/site-packages/torch/distributed/_composable/fsdp/_fsdp_param_group.py", line 154, in _raise_not_implemented_if_2d
[rank4]: raise NotImplementedError(
[rank4]: NotImplementedError: 2D state_dict is under development. Please check https://github.com/pytorch/pytorch/issues/129627 for more details.
I'm running the Llama 3 tensor-parallel example provided here: https://github.com/Lightning-AI/pytorch-lightning/tree/master/examples/pytorch/tensor_parallel
How can I make it work? Thanks!
@loretoparisi I think the latest nightlies should have this fix now!
Yes, I have tested on the latest nightly, not with torchtitan but with a custom repository. It works fine now :)
Hi!
Amazing job on the example implementations of all of these cutting-edge training features! However, when trying to run the following configuration, I ran into issues:
Once it starts checkpointing, I get the following error:
Checking https://github.com/pytorch/pytorch/issues/129627, it's a bit unclear what the status of the fix is, and how an FSDP2+TP pipeline should be handled with DCP at this point in time (as of PyTorch 2.4).
The torchtitan README mentions TP/DCP/FSDP/PP all working, so I'd have expected this to work too, especially since there's also a 3D parallel test in test_runner.py (which presumably also doesn't work?). If there's a workaround for this (other than just not using DCP), I'd love to know; otherwise it might be good to add a comment or a mention in the README that this is known to be broken at the moment.
My env:
Thanks!