npuichigo opened this issue 2 months ago
I think the issue here is that you're not setting an auto_wrap_policy, which is required? cc @pacman100 to confirm
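For reference, a minimal sketch of setting an auto_wrap_policy through accelerate's FullyShardedDataParallelPlugin, assuming the plugin accepts a callable auto_wrap_policy; GPT2Block is only an illustrative layer class, substitute your model's transformer block type:

```python
import functools

from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.gpt2.modeling_gpt2 import GPT2Block

# Wrap every transformer block of the given class in its own FSDP unit.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={GPT2Block},  # example layer class, adjust to your model
)

plugin = FullyShardedDataParallelPlugin(auto_wrap_policy=wrap_policy)
accelerator = Accelerator(fsdp_plugin=plugin)
```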
Setting an auto_wrap_policy may be another way, but I can find no way to set up the process_group and device_mesh that are part of the FSDP interface:

torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)
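For example, outside of accelerate, hybrid sharding can be driven with a 2D device_mesh passed straight to FSDP. A minimal sketch, assuming 2 nodes x 8 GPUs (the mesh shape is illustrative, and the outer dimension is assumed to replicate across nodes while the inner one shards within a node); build_model() is a hypothetical helper returning an nn.Module:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")  # assumes ranks were launched via torchrun or similar

# 2D mesh: dim 0 replicates across nodes, dim 1 shards within a node.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

model = build_model()  # hypothetical helper
fsdp_model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,
)
```

It is exactly this device_mesh (or the equivalent process_group) argument that cannot currently be reached through the plugin.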
True, we need to enable this in the FullyShardedDataParallelPlugin and make it work nicely when things are set via accelerate launch. (Though this particular value couldn't be a parameter set there, for obvious reasons.)
Sounds like one implication here is that the transformers Trainer's --fsdp hybrid_shard / --fsdp hybrid_shard_zero2 can't possibly run successfully, since process_group is not among the Trainer arguments.
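To make the gap concrete, here is a sketch of the process-group pair one would have to build by hand for hybrid sharding when no device_mesh is used, assuming 2 nodes x 8 GPUs and assuming (per the torch 2.2 docs) that FSDP accepts a (intra-node shard group, inter-node replicate group) tuple for hybrid strategies. Nothing in the Trainer lets you pass this in:

```python
import torch.distributed as dist

dist.init_process_group("nccl")  # launched via torchrun / accelerate launch
world_size, rank = dist.get_world_size(), dist.get_rank()
gpus_per_node = 8  # illustrative topology
num_nodes = world_size // gpus_per_node

# Every rank must create every group, even ones it does not belong to.
shard_groups = [
    dist.new_group(list(range(n * gpus_per_node, (n + 1) * gpus_per_node)))
    for n in range(num_nodes)
]
replicate_groups = [
    dist.new_group(list(range(r, world_size, gpus_per_node)))
    for r in range(gpus_per_node)
]

process_group = (
    shard_groups[rank // gpus_per_node],    # shard within this node
    replicate_groups[rank % gpus_per_node], # replicate across nodes
)
```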
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Information
Tasks
no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
Reproduction
Set the FSDP sharding strategy to ShardingStrategy.HYBRID_SHARD or ShardingStrategy._HYBRID_SHARD_ZERO2.
Expected behavior
There should be a way to provide the process group or device_mesh, since they are parameters of FSDP: https://pytorch.org/docs/2.2/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel
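A hypothetical sketch of what such an extension could look like; the device_mesh argument on FullyShardedDataParallelPlugin does not exist today and is exactly what this issue asks to add:

```python
from accelerate import Accelerator, FullyShardedDataParallelPlugin
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import ShardingStrategy

# Illustrative 2-node x 8-GPU mesh: outer dim replicates, inner dim shards.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replicate", "shard"))

plugin = FullyShardedDataParallelPlugin(
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
    device_mesh=mesh,  # hypothetical parameter, not in the current plugin
)
accelerator = Accelerator(fsdp_plugin=plugin)
```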