alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

[FEATURE] pass global config to worker and set manual sharding of intermediates #928

Closed · ZYHowell closed this 1 year ago

ZYHowell commented 1 year ago
  1. Pass the global_config to MeshHostWorker's constructor. This allows a user to specify NCCL-related configs and more.
  2. Set manual sharding of intermediate tensors. Prior to this, we could only set data parallelism for intermediate stages with global inputs like attention_mask.
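The first point describes a common pattern: instead of each worker process reading a module-level global config (which does not propagate across processes), the driver passes its config snapshot into the worker's constructor. The sketch below illustrates that pattern with hypothetical names (`GlobalConfig`, `MeshHostWorkerSketch`, `nccl_use_multistream`); it is not Alpa's actual API.

```python
from dataclasses import dataclass, field

@dataclass
class GlobalConfig:
    # Hypothetical NCCL-related option; stands in for the real
    # communication settings a user might want to tune.
    nccl_use_multistream: bool = True
    extra: dict = field(default_factory=dict)

class MeshHostWorkerSketch:
    """Illustrative worker that receives the driver's config at construction
    time, rather than reading process-local module defaults."""

    def __init__(self, host_id: int, global_config: GlobalConfig):
        self.host_id = host_id
        # Keep the driver's snapshot so every later call on this worker
        # sees the same settings the user set on the driver side.
        self.global_config = global_config

# Driver side: customize the config, then hand it to each worker.
driver_config = GlobalConfig(nccl_use_multistream=False)
worker = MeshHostWorkerSketch(host_id=0, global_config=driver_config)
```

With this shape, any option added to the config object reaches the workers automatically, with no per-option plumbing.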