facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

Can some layers' parameters be excluded from sharding? #1123

Open robotcator opened 1 year ago

robotcator commented 1 year ago

The default_auto_wrap_policy function has an exclude_wrap_modules parameter for excluding module types from wrapping. Does that mean those modules' parameters will not be sharded? How can I check whether it works? @min-xu-ai

min-xu-ai commented 1 year ago

Thanks for the question and tagging me.

No. The params will still be sharded by the outer FSDP wrapper; the excluded modules are only skipped by the auto_wrap algorithm when it determines the nested wrapping structure. The actual sharding is determined by the wrapper's process_group argument: if the process group contains only a single GPU, the params are not sharded.

To check the wrapping structure, you can simply print() the model and look at where the FSDP wrappers were inserted.
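To make the answer concrete, here is a small pure-Python toy model (not fairscale's code) of how a size-based auto-wrap policy with an exclusion set decides where nested wrappers go. The `Module` class and `toy_auto_wrap` function are hypothetical stand-ins; the logic only loosely mirrors the idea behind `default_auto_wrap_policy` with `min_num_params` and `exclude_wrap_modules`, under the assumption that excluded module types are still recursed into but never get a wrapper of their own. Any params not claimed by an inner wrapper end up owned, and therefore sharded, by the outer root wrapper.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class Module:
    """Hypothetical stand-in for an nn.Module: a name, a type name, its own
    (direct) parameter count, and its child modules."""
    name: str
    kind: str
    own_params: int = 0
    children: List["Module"] = field(default_factory=list)


def toy_auto_wrap(root: Module, min_num_params: int, exclude: Set[str]) -> Dict[str, int]:
    """Return {wrapper name: number of params that wrapper owns (and shards)}.

    A non-root module gets its own wrapper when the params not yet claimed by a
    descendant wrapper reach min_num_params and its kind is not excluded.
    Excluded kinds are still recursed into. The root is always wrapped (the
    outer FSDP) and owns every param that no inner wrapper claimed.
    """
    owners: Dict[str, int] = {}

    def visit(m: Module) -> int:
        # Post-order: children decide first, then count what is still unclaimed.
        unclaimed = m.own_params + sum(visit(c) for c in m.children)
        if m is not root and unclaimed >= min_num_params and m.kind not in exclude:
            owners[m.name] = unclaimed  # this module gets its own wrapper
            return 0                    # nothing left for ancestors to claim
        return unclaimed

    owners[root.name] = visit(root)  # leftovers go to the outer wrapper
    return owners


model = Module("model", "Net", own_params=10, children=[
    Module("embed", "Embedding", own_params=500),
    Module("block0", "Block", children=[
        Module("block0.attn", "Attention", own_params=300),
        Module("block0.mlp", "MLP", own_params=400),
    ]),
    Module("head", "Linear", own_params=200),
])

print(toy_auto_wrap(model, 250, exclude=set()))
# {'embed': 500, 'block0.attn': 300, 'block0.mlp': 400, 'model': 210}
print(toy_auto_wrap(model, 250, exclude={"Embedding"}))
# {'block0.attn': 300, 'block0.mlp': 400, 'model': 710}
```

With the embedding type excluded, it gets no wrapper of its own, but its 500 params are not left unsharded: they are folded into the outer wrapper, so the total owned across wrappers stays 1410 in both runs.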

robotcator commented 1 year ago

@min-xu-ai Thank you for your kind reply. Does that mean the FSDP wrapper will flatten all model parameters and shard them across ranks? If so, how will an inner module wrapped by FSDP behave?

Regarding checking the wrapping structure with print(): is there any API to check a specific module's sharded parameters, e.g. something on the FSDP wrapper that shows the shard size on each rank?

robotcator commented 1 year ago

@min-xu-ai I have another question about the FSDP wrapper on a single GPU: an exception is raised when saving a checkpoint. Any idea how to handle this error?

  File "/opt/conda/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 2400, in gather_full_optim_state_dict
    state, singleton_state = self._gather_optim_state(sd.pop("state"))
  File "/opt/conda/lib/python3.8/site-packages/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 2344, in _gather_optim_state
    desired_buffer_size = non_shared_params[0]._full_param_padded.size()
AttributeError: 'FlatParameter' object has no attribute '_full_param_padded'

min-xu-ai commented 1 year ago

Does that mean the FSDP wrapper will flatten all model parameters and shard them across ranks? If so, how will an inner module wrapped by FSDP behave?

Both inner and outer wrappers shard in the same way, based on the process group each one is given. They just own different sets of params, depending on which param is wrapped by which wrapper. Whether the params are flattened is controlled by a separate argument to the wrappers.
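A rough sketch of the per-wrapper sharding arithmetic, written as a pure-Python toy model rather than fairscale's implementation. The assumptions: a wrapper either concatenates the params it owns into one flat buffer (the flattened case; in fairscale this is, as far as I know, the `flatten_parameters` constructor argument) or keeps them separate, and each buffer is split into `world_size` chunks of equal padded size, with the trailing rank(s) zero-padded. The helper names are hypothetical. With a single-rank process group, the one "shard" is the whole buffer, which is why nothing is actually sharded in that case.

```python
from typing import List


def chunk_size(numel: int, world_size: int) -> int:
    """Elements every rank stores for one buffer, padding included (ceil division)."""
    return (numel + world_size - 1) // world_size


def real_elements_per_rank(numel: int, world_size: int) -> List[int]:
    """Unpadded elements each rank holds; later ranks may hold fewer (or none)
    and are zero-padded up to chunk_size()."""
    c = chunk_size(numel, world_size)
    return [max(0, min(c, numel - r * c)) for r in range(world_size)]


def local_numel(param_sizes: List[int], world_size: int, flatten: bool) -> int:
    """Elements stored on each rank for all params owned by ONE wrapper."""
    if flatten:
        # All owned params concatenated into one flat buffer, sharded once.
        return chunk_size(sum(param_sizes), world_size)
    # Each param sharded (and padded) on its own.
    return sum(chunk_size(n, world_size) for n in param_sizes)


owned = [10, 7, 5]  # hypothetical param sizes owned by one wrapper (22 total)
print(real_elements_per_rank(sum(owned), world_size=4))  # [6, 6, 6, 4]
print(local_numel(owned, world_size=4, flatten=True))    # 6
print(local_numel(owned, world_size=4, flatten=False))   # 7  (3 + 2 + 2)
print(local_numel(owned, world_size=1, flatten=True))    # 22, i.e. not sharded
```

In this model, flattening pays the padding cost once per wrapper instead of once per param, which is one reason to prefer it when many small params share a wrapper.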

is there any API to check a specific module's sharded parameters

I don't think there is any API for this, but you can inspect them directly as long as you can access the wrapper object.

Any idea how to handle this error?

No idea. This might be a corner-case bug. Maybe you can try PyTorch's version of FSDP and see whether it works better for you.

robotcator commented 1 year ago

Thank you for your kind reply. I got it.