huggingface / accelerate

🚀 A simple way to launch, train, and use PyTorch models on almost any device and distributed configuration, automatic mixed precision (including fp8), and easy-to-configure FSDP and DeepSpeed support
https://huggingface.co/docs/accelerate
Apache License 2.0

Setting FSDP FULL_STATE_DICT explicitly doesn't work #3072

Closed: thepowerfuldeez closed this issue 1 month ago

thepowerfuldeez commented 1 month ago

System Info

accelerate==0.34.0

Reproduction

trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

(Setting FULL_STATE_DICT through the FSDP config doesn't work either.)

Expected behavior

Calling trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") works.

muellerzr commented 1 month ago

@thepowerfuldeez is there a reason why you're trying to dynamically change the state dict type? We should only be relying on how you instantiate the FSDPPlugin here, which is why you found this change.

(And why call set_state_dict_type manually at all?)

thepowerfuldeez commented 1 month ago

@muellerzr As I recall, CPU offload didn't work with the sharded state dict, so the code switches the state dict type dynamically. I tried removing that part of the code and setting the type from the config instead, but that doesn't work either.

muellerzr commented 1 month ago

If you're doing it this way, you would need to first override the state dict type manually on the plugin and then call set_state_dict_type(). But I really need to see some code to wrap my head around this workflow.
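The two-step workaround from the last reply (override the type on the plugin, then call set_state_dict_type() with no argument) can be sketched roughly as follows. This is a hypothetical, self-contained stand-in, not accelerate's implementation: ToyFSDPPlugin, StateDictType, state_dict_config and switch_to_full_state_dict are illustrative names, and the sketch assumes, as the reply implies, that set_state_dict_type() re-derives its configuration from the plugin's own state_dict_type attribute rather than from an argument.

```python
from enum import Enum


class StateDictType(Enum):
    # Hypothetical mirror of the two state dict types discussed in this thread
    FULL_STATE_DICT = 1
    SHARDED_STATE_DICT = 2


class ToyFSDPPlugin:
    """Hypothetical stand-in for an FSDP plugin; NOT accelerate's code.

    Assumption: set_state_dict_type() takes no argument and rebuilds its
    config from the plugin's own state_dict_type attribute.
    """

    def __init__(self, state_dict_type="SHARDED_STATE_DICT"):
        # The type is fixed by how the plugin is instantiated
        self.state_dict_type = state_dict_type
        self.state_dict_config = None
        self.set_state_dict_type()

    def set_state_dict_type(self):
        # Normalize a string such as "FULL_STATE_DICT" into the enum
        if isinstance(self.state_dict_type, str):
            self.state_dict_type = StateDictType[self.state_dict_type.upper()]
        # Rebuild the dependent (illustrative) config from the current attribute
        self.state_dict_config = f"config for {self.state_dict_type.name}"


def switch_to_full_state_dict(plugin):
    # Step 1: override the attribute on the plugin itself
    plugin.state_dict_type = "FULL_STATE_DICT"
    # Step 2: call the setter with no argument so it re-reads the attribute
    plugin.set_state_dict_type()
    return plugin


plugin = ToyFSDPPlugin()
switch_to_full_state_dict(plugin)
print(plugin.state_dict_config)  # prints "config for FULL_STATE_DICT"
```

On a real setup, the analogous calls would presumably be assigning trainer.accelerator.state.fsdp_plugin.state_dict_type and then calling trainer.accelerator.state.fsdp_plugin.set_state_dict_type(); whether that behaves the same way in accelerate 0.34.0 is not verified here.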