robotcator opened 1 year ago
Thanks for the question and tagging me.
No, the params will still be sharded by the outer FSDP wrapper; they are just excluded by the auto_wrap algorithm when it determines the nested wrapping structure. The actual sharding config is determined by the wrapper's process_group argument. If the process_group contains only a single GPU, then the params are not sharded.
To check the wrapping structure, you can simply print() out the model and examine where the FSDP wrappers are inserted.
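For illustration, here is a minimal sketch of what that wrapping and print() check could look like; the toy model, the group choice, and the CUDA setup are assumptions for the example, not code from this thread:

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is already initialized (e.g. launched via torchrun)
# and a GPU is available; the two-layer model is just a toy example.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# process_group controls how the flat params are sharded. None means the
# default (world) group; a group containing a single rank means no sharding.
fsdp_model = FSDP(model, process_group=None, flatten_parameters=True)

# Printing the wrapped model shows where the FSDP wrappers were inserted.
print(fsdp_model)
```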
@min-xu-ai Thank you for your kind reply. Does that mean the FSDP-wrapped module will flatten all model parameters and shard them across each rank? If so, what will the behavior of an inner module with its own FSDP wrapper be?
Regarding 'To check the wrapping structure, you can simply print() out the model and examine where the FSDP wrappers are inserted.': is there any API to check a specific sharded module's parameters, e.g. which params an FSDP-wrapped layer holds, so I can check the shard size on each rank?
@min-xu-ai There is another question about the FSDP wrapper on a single GPU: at the save-checkpoint stage, an exception is raised. Any idea how to handle this error?
File "/opt/conda/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 2400, in gather_full_optim_state_dict
state, singleton_state = self._gather_optim_state(sd.pop("state"))
File "/opt/conda/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 2344, in _gather_optim_state
desired_buffer_size = non_shared_params[0]._full_param_padded.size()
AttributeError: 'FlatParameter' object has no attribute '_full_param_padded'
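For context, the failing call typically comes from a checkpoint-saving step along the lines of the hedged sketch below; fsdp_model and optimizer are placeholders for the user's actual objects, not code from this thread:

```python
import torch

# fsdp_model is an FSDP-wrapped model (see earlier sketch); the Adam
# optimizer is an illustrative choice.
optimizer = torch.optim.Adam(fsdp_model.parameters(), lr=1e-3)

# ... training steps ...

# gather_full_optim_state_dict() consolidates the sharded optimizer state;
# this is the call whose internals raise the AttributeError above.
full_osd = fsdp_model.gather_full_optim_state_dict(optimizer)
if full_osd is not None:  # only the recipient rank (rank 0) gets the dict
    torch.save(full_osd, "optim_full.pt")
```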
Does that mean the FSDP-wrapped module will flatten all model parameters and shard them across each rank? If so, what will the behavior of an inner module with its own FSDP wrapper be?
Both inner and outer wrappers do the same sharding based on the process group each is given. They just own different sets of params, depending on which param is wrapped by which wrapper. Flattening or not is a separate argument to the wrappers.
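A minimal sketch of nested wrapping under those rules; the layer sizes and the per-wrapper flatten choices are illustrative assumptions:

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Each wrapper shards the params it owns over the process group it is given;
# flatten_parameters is chosen independently per wrapper.
inner = FSDP(nn.Linear(1024, 1024).cuda(), flatten_parameters=True)
outer = FSDP(nn.Sequential(inner, nn.Linear(1024, 1024).cuda()),
             flatten_parameters=False)

# Printing shows which wrapper owns which params.
print(outer)
```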
is there any API to check a specific sharded module's parameters
I don't think there is any API for this, but you can inspect them directly as long as you can access the wrapper object.
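For example, something along these lines should work; note that walking named_modules() and reading each wrapper's params list is an assumption about FairScale internals (an undocumented attribute), not an official API:

```python
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Assumes torch.distributed is initialized and fsdp_model is the wrapped model.
rank = dist.get_rank()
for name, module in fsdp_model.named_modules():
    if isinstance(module, FSDP):
        # module.params holds the (possibly flattened) params this wrapper
        # owns; numel() is the size of the local shard on this rank.
        for p in module.params:
            print(f"rank {rank} | {name or '<root>'}: {p.numel()} elements")
```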
Any idea how to handle this error?
No idea. This might be a corner-case bug. Maybe you can try PyTorch's version of FSDP and see whether that one works better for you.
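If you do try the PyTorch-native FSDP, the optimizer-state gathering looks roughly like the sketch below; model and the training loop are placeholders, and the exact API may differ across PyTorch versions:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as TorchFSDP

# model is a placeholder for your own nn.Module.
torch_fsdp_model = TorchFSDP(model.cuda())
optimizer = torch.optim.Adam(torch_fsdp_model.parameters(), lr=1e-3)

# ... training steps ...

# Gathers the full optimizer state dict onto rank 0.
full_osd = TorchFSDP.full_optim_state_dict(torch_fsdp_model, optimizer)
```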
Thank you for your kind reply, I got it.
This default_auto_wrap_policy function has a parameter exclude_wrap_modules for excluding module types from wrapping. Does that mean those modules' parameters will not be sharded? How can I check whether it works or not? @min-xu-ai
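For reference, a minimal sketch of how exclude_wrap_modules is usually passed via functools.partial, and how print() can confirm the result; the nn.Embedding choice and the size threshold are illustrative assumptions, and the exact auto_wrap signature may vary across fairscale versions:

```python
import functools
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, auto_wrap, default_auto_wrap_policy

# Exclude nn.Embedding from getting its own nested FSDP wrapper.
policy = functools.partial(
    default_auto_wrap_policy,
    min_num_params=int(1e5),
    exclude_wrap_modules={nn.Embedding},
)

model = nn.Sequential(nn.Embedding(1000, 512), nn.Linear(512, 512)).cuda()

with enable_wrap(wrapper_cls=FSDP):
    wrapped = auto_wrap(model, auto_wrap_policy=policy)

# Excluded modules show up without a nested FSDP wrapper of their own; their
# params are still flattened and sharded by the enclosing (root) wrapper.
print(FSDP(wrapped))
```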