Open · keneoneth opened this issue 1 year ago

Hello there, I have been reading your research paper "Decoupled Model Schedule for Deep Learning Training". In particular, in part (3) on tensor parallelism, it is mentioned that "Since the output tensor only holds partial results after sharding, we need to conduct all_reduce to aggregate outputs from different device". May I know which part of the source code in this repository performs the all_reduce operation? Right now I am looking at build.py, and I believe the code there handles the sharded parameters and performs the split when a shard is added by the user, but I am not exactly sure where the all_reduce operation is performed. Any help will be appreciated. Thank you.
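To illustrate what I mean by "partial results", here is a minimal standalone sketch (plain PyTorch, not code from this repository) of a row-sharded linear layer whose per-rank output is an incomplete sum until it is all_reduced:

```python
import torch
import torch.distributed as dist

def row_parallel_linear(x_local: torch.Tensor, w_local: torch.Tensor) -> torch.Tensor:
    """Compute y = x @ W^T with W sharded along its input dimension.

    x_local: [batch, in_features // world_size] slice of the input
    w_local: [out_features, in_features // world_size] slice of the weight
    Each rank's product is only a partial sum of the full output, so an
    all_reduce is needed to recover the correct result.
    """
    partial = x_local @ w_local.t()                 # partial result on this rank
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)  # aggregate across ranks
    return partial
```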
Thanks for your interest in our project! We have already encapsulated those collective operators in our schedule language, so users only need to write schedule primitives to perform `all_reduce`, `all_gather`, etc. Please check the following schedule of BERT, which shards the MLP layer and uses `all_reduce` to aggregate the outputs.
https://github.com/awslabs/slapo/blob/v0.0.3/slapo/model_schedule/bert.py#L303-L313
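Roughly, the relevant primitives look like this (a simplified sketch of the linked lines; the module paths and sync arguments here are illustrative rather than copied verbatim, so please refer to the linked file for the exact code):

```python
# Shard the BERT MLP Megatron-style and aggregate the row-parallel output.
sch[f"{prefix}.intermediate.dense"].shard("weight", axis=0)  # column-parallel
sch[f"{prefix}.intermediate.dense"].shard("bias", axis=0)
sch[f"{prefix}.output.dense"].shard("weight", axis=1)        # row-parallel
# The row-parallel output only holds partial sums, so all_reduce it:
sch[f"{prefix}.output.dense"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
```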
Thanks for replying! I've studied the code further and have a few follow-up questions.
1) What exactly is the difference between "fwd_pre" and "fwd_post" when performing a sync? The paper seems to divide syncs into forward/backward only.
2) I am wondering whether the all_reduce sync after sharding could be automated instead of being added by the user, because Slapo seems able to infer whether a shard's output is a 'partition' or 'partial', and if I understand correctly, only 'partial' requires a follow-up all_reduce sync. In addition, it seems that the all_reduce sync with mode=backward is normally added at the top of a stack of linear layers, whereas the all_reduce sync with mode=forward is added at the bottom (see the sketch after this list for my mental model).
3) The comment at https://github.com/awslabs/slapo/blob/v0.0.3/slapo/sharding/shard_ops.py#L200-L202 looks a bit confusing to me. Would you mind elaborating on why a bwd_post sync should be registered as a forward pre hook?
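To make my understanding in 2) concrete, here is a standalone sketch (plain PyTorch, my own names, not Slapo code) of the two-linear MLP pattern I have in mind:

```python
import torch
import torch.distributed as dist

class ShardedMLP(torch.nn.Module):
    """Megatron-style MLP as I understand it (illustrative only).

    fc1 is column-parallel: its output is a 'partition', so no forward sync
    is needed, but the gradient w.r.t. its input must be all_reduced in the
    backward pass (the mode=backward sync at the 'top' of the stack).
    fc2 is row-parallel: its output is 'partial', so it needs a forward
    all_reduce right after it (the mode=forward sync at the 'bottom').
    """

    def __init__(self, d: int, d_ff: int, world_size: int):
        super().__init__()
        self.fc1 = torch.nn.Linear(d, d_ff // world_size)  # column-parallel
        self.fc2 = torch.nn.Linear(d_ff // world_size, d)  # row-parallel

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NB: dist.all_reduce here is not autograd-aware; a real implementation
        # would wrap it (and the backward all_reduce on x's gradient) in custom
        # autograd functions. Omitted for brevity.
        h = torch.relu(self.fc1(x))  # 'partition' output: no sync needed
        y = self.fc2(h)              # 'partial' sums on each rank
        dist.all_reduce(y)           # forward all_reduce at the bottom
        return y
```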
Thanks a lot!
Regarding 2): yes, we are exploring automating the insertion of these synchronizations, so that users will not need to manually write the `.shard` or `.sync` primitives. Please stay tuned.