awslabs / slapo

A schedule language for large model training
https://awslabs.github.io/slapo/
Apache License 2.0

Enquiries about Parameter Sharding #98

Open keneoneth opened 1 year ago

keneoneth commented 1 year ago

Hello there, I have been reading your research paper "Decoupled Model Schedule for Deep Learning Training". In particular, in part (3) on tensor parallelism, it is mentioned that "Since the output tensor only holds partial results after sharding, we need to conduct all_reduce to aggregate outputs from different device". May I know which part of the source code in this repository performs the all_reduce operation? I am currently looking at build.py, and I believe the snippet below handles the sharded parameters and splits a parameter when a shard is added by the user, but I am not exactly sure where the all_reduce operation is done. Any help will be appreciated. Thank you.

```python
# Only keep the partition for this device for sharded params.
tp_rank = sch.rank
cnt_shard = 0
for param_name, param in sch.mod.named_parameters(recurse=False):
    is_found = False
    for idx, new_size in enumerate(new_param_shapes[param_name]):
        if new_size != param.shape[idx]:
            assert not is_found, "Cannot have two sharded dimensions!"
            sharded_size = new_size
            axis = idx
            is_found = True
    if is_found:
        cnt_shard += 1
        sharded_param = param.detach().split(sharded_size, dim=axis)[tp_rank]
        sharded_param = sharded_param.contiguous()
        new_param = nn.Parameter(sharded_param)
        sch.mod.register_parameter(param_name, new_param)
        transfor_param_tags(sch, param, new_param)
```
chhzh123 commented 1 year ago

Thanks for your interest in our project! We have already encapsulated those collective operators in our schedule language, so users only need to write the schedule primitives in order to conduct all_reduce, all_gather, etc. Please check the following schedule of BERT, which shards the MLP layer and uses all_reduce to aggregate the outputs. https://github.com/awslabs/slapo/blob/v0.0.3/slapo/model_schedule/bert.py#L303-L313
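
In short, the relevant part of that schedule looks roughly like the sketch below. This is a paraphrase rather than a verbatim copy of the linked bert.py; the submodule names and keyword arguments are illustrative:

```python
# Shard the first MLP linear layer along the output dimension,
# so each device holds a partition of its output.
sch[f"{prefix}.intermediate.dense"].shard("weight", axis=0)
sch[f"{prefix}.intermediate.dense"].shard("bias", axis=0)

# Shard the second linear layer along the input dimension,
# so its output only holds partial results on each device.
sch[f"{prefix}.output.dense"].shard("weight", axis=1)

# all_reduce the partial outputs in the forward pass ...
sch[f"{prefix}.output.dense"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")
# ... and all_reduce the input gradients of the first layer in the backward pass.
sch[f"{prefix}.intermediate.dense"].sync(mode="bwd_post", sync_op_or_fn="all_reduce")
```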

keneoneth commented 1 year ago

Thanks for replying! I've studied the code further and have some follow-up questions.

1) What exactly is the difference between "fwd_pre" and "fwd_post" when doing a sync? In the paper, the sync seems to be divided into forward and backward only.

2) I am wondering whether the all_reduce sync after sharding could be automated instead of being added by the user, because it seems Slapo can infer whether the shard result is a 'partition' or 'partial', and, if I understand correctly, only 'partial' requires a follow-up all_reduce sync (see the small check below these questions). Also, it seems that normally the all_reduce sync with mode=backward is added to the top linear layer of the stack, whereas the all_reduce sync with mode=forward is added to the bottom one.

3) The comment on https://github.com/awslabs/slapo/blob/v0.0.3/slapo/sharding/shard_ops.py#L200-L202 looks a bit confusing to me. Would you mind elaborating a bit on why a bwd_post sync should be registered as a forward pre hook?
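
Regarding (2), here is the small check I mean, in plain PyTorch (illustration only, not Slapo code): sharding the weight along the output dimension yields a partitioned output, while sharding along the input dimension yields partial sums that are only correct after an all_reduce-style summation.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)        # batch of activations
w = torch.randn(8, 6)        # full weight: in=8, out=6
full = x @ w                 # reference output on a single device

# Shard along the output dimension: each "device" holds a slice of the
# columns, so the output is *partitioned* and only needs to be concatenated.
w0, w1 = w.split(3, dim=1)
partitioned = torch.cat([x @ w0, x @ w1], dim=1)
assert torch.allclose(full, partitioned)

# Shard along the input dimension: each "device" sees only part of the
# reduction, so each output is *partial* and must be summed (all_reduce).
w0, w1 = w.split(4, dim=0)
x0, x1 = x.split(4, dim=1)
partial = (x0 @ w0) + (x1 @ w1)   # the "all_reduce" here is just a sum
assert torch.allclose(full, partial)
```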

Thanks a lot!

chhzh123 commented 1 year ago
  1. It depends on whether the synchronization happens before or after the forward function of a module. We follow the "global hooks for modules" in PyTorch to design the API. For example, adding a "fwd_post" sync op to a linear layer is equivalent to adding a "fwd_pre" sync op to the subsequent activation layer.
  2. Right, this is correct. We currently expose those sync APIs to users so that they have better control over their models. One of our ongoing efforts is to automatically partition the model and conduct the communication across multiple devices. Once this feature is added, even non-expert users will be able to enable efficient parallelism without writing those .shard or .sync primitives. Please stay tuned.
  3. As you can see on the "global hooks for modules" page, PyTorch provides both register_module_backward_hook and register_module_full_backward_hook, and the exact semantics of those backward hooks are not clear to us; we experienced some convergence issues when integrating them into our framework. As a workaround, we create autograd functions and insert them in front of the module, so they serve the same purpose in the backward pass. One example can be seen here. Because we customize the backward function, registering a forward pre hook is enough to implement a bwd_post sync; a standalone sketch of the trick is shown below.
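
For illustration only, the trick looks roughly like this minimal sketch in plain PyTorch. It is not Slapo's actual shard_ops.py code; names such as `AllReduceGrad` and `bwd_post_sync_hook` are made up, and it assumes torch.distributed has already been initialized:

```python
import torch
import torch.distributed as dist

class AllReduceGrad(torch.autograd.Function):
    """Identity in the forward pass; all_reduce the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad)
        return grad

def bwd_post_sync_hook(module, inputs):
    # Registered as a *forward pre* hook: it wraps the module's inputs, so the
    # custom backward above runs right after the module has produced its input
    # gradients, i.e., at the "bwd_post" position of this module.
    return tuple(AllReduceGrad.apply(x) for x in inputs)

# Hypothetical usage on some sharded linear layer `linear`:
# linear.register_forward_pre_hook(bwd_post_sync_hook)
```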