alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

[FEATURE] Reduce congestion of sending on one mesh #802

Open ZYHowell opened 1 year ago

ZYHowell commented 1 year ago

This can be a starting point to learn runtime_emitter and cross_mesh_resharding.

Background

In Pipeshard Parallel, when a tensor must be received from another mesh, we always choose the mesh that produces it, as happens here. However, when the tensor is consumed by multiple pipeline stages, a better solution MIGHT be for a later consumer to receive the tensor from one of the earlier consumers. For example, if stage 0 sends a tensor to stages 1, 2, and 3, but stages 1-3 otherwise have little communication, then stage 2 can receive it from stage 1, and stage 3 can receive it from stage 2.
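To make the congestion concrete, here is a toy count of how many outgoing transfers each stage's mesh issues under the current fan-out plan versus the chained plan from the example. It is a sketch that assumes one stage per mesh and equal-sized transfers; the `send_load` helper is hypothetical, not part of Alpa.

```python
from collections import Counter


def send_load(transfers):
    """Count outgoing transfers per source stage.

    `transfers` is a list of (src_stage, dst_stage) pairs for one tensor.
    Assumes one stage per mesh and that every transfer has the same size.
    """
    return Counter(src for src, _ in transfers)


# Current behavior: the producer (stage 0) sends to every consumer.
fan_out = [(0, 1), (0, 2), (0, 3)]
# Proposed: each later consumer receives from the previous consumer.
chained = [(0, 1), (1, 2), (2, 3)]

print(send_load(fan_out))  # Counter({0: 3})
print(send_load(chained))  # Counter({0: 1, 1: 1, 2: 1})
```

Both plans move the same number of bytes in total; the chained plan only spreads the sending work so that no single mesh has to serialize three sends.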

TODO

jon-chuang commented 1 year ago

Hello, please assign this to me :)

To my understanding, Ray's object store is not used for activation/param transfer, and Ray is only used for task orchestration, correct?

ZYHowell commented 1 year ago

Right. The compilation flow is: the wrapped Jaxprs of each pipeline stage are turned into SymbolicReshardingTasks by CrossMeshCommunicator, and PipelineInstEmitter then lowers those into PipelineInstructions (SEND/RECV/BROADCAST). Each pipeline instruction is orchestrated by code in the collective folder, which ultimately calls NCCL. The congestion issue may be solvable by considering only the compilation part.
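A toy model of the last step of that flow, where each cross-mesh task becomes a SEND on the source mesh and a RECV on the destination mesh. The `ReshardingTask` type and `emit` function are hypothetical simplifications for illustration only; Alpa's actual SymbolicReshardingTask and PipelineInstEmitter carry sharding specs and can also emit BROADCAST.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ReshardingTask:
    """Hypothetical stand-in: move `var` from mesh `src` to mesh `dst`."""
    var: str
    src: int
    dst: int


def emit(tasks):
    """Lower resharding tasks into per-mesh instruction lists."""
    schedule = defaultdict(list)  # mesh id -> list of (opcode, var, peer mesh)
    for t in tasks:
        schedule[t.src].append(("SEND", t.var, t.dst))
        schedule[t.dst].append(("RECV", t.var, t.src))
    return dict(schedule)


schedule = emit([ReshardingTask("x", 0, 1), ReshardingTask("x", 0, 2)])
print(schedule[0])  # [('SEND', 'x', 1), ('SEND', 'x', 2)]
```

In this toy model the per-mesh load is fully determined by which mesh appears as `src` in the task list, which is why changing only the compile-time choice of source is enough to move traffic off the producer.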

jon-chuang commented 1 year ago

For the first task, writing the pass, I will start by writing a test showing that the desired transform is applied to the jaxpr.

As for scheduling, I guess the tensor should be queued instantly upon being received.

jon-chuang commented 1 year ago

One complication: are all vars uniquely named in the HLO module, i.e., is it in SSA form?

ZYHowell commented 1 year ago

For Jaxpr, each Var has its own id (an int) unless it's a DropVar (a placeholder); I'd expect most of the work to happen at this level. For HLO, each var corresponds to a specific HloInstruction.

jon-chuang commented 1 year ago

Hello, another question: are we guaranteed that the stages are sequentially dependent, i.e., that they form a chain rather than a DAG? It doesn't matter too much, but presumably, for a DAG structure:

stage 0 -> stage 1
        -> stage 2

Where stage 2 has no functional dependence on stage 1, we should indeed broadcast from stage 0 to both stage 1 and stage 2 to prevent any stalls.
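One way to encode this distinction: relay a tensor only through an earlier consumer that the receiving stage already (transitively) depends on, and otherwise fall back to receiving from the producer. This is a sketch; `deps`, `depends_on`, and `choose_source` are hypothetical names, and stages are plain integers.

```python
def depends_on(deps, stage, other):
    """Return True if `stage` transitively depends on `other`.

    `deps` maps a stage id to the set of stage ids it directly consumes from.
    """
    stack, seen = [stage], set()
    while stack:
        cur = stack.pop()
        for parent in deps.get(cur, ()):
            if parent == other:
                return True
            if parent not in seen:
                seen.add(parent)
                stack.append(parent)
    return False


def choose_source(deps, producer, prior_consumers, consumer):
    """Relay from the latest prior consumer that `consumer` already waits on.

    Falls back to the producer when no prior consumer is an ancestor,
    since relaying through an unrelated branch could introduce a stall.
    """
    for cand in reversed(prior_consumers):
        if depends_on(deps, consumer, cand):
            return cand
    return producer


# DAG from the comment: stage 0 feeds both 1 and 2; 2 does not depend on 1.
print(choose_source({1: {0}, 2: {0}}, 0, [1], 2))  # 0 (receive from producer)
# Chain: 2 depends on 1, so relaying through 1 cannot add a stall.
print(choose_source({1: {0}, 2: {1}}, 0, [1], 2))  # 1
```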

However, perhaps we can ignore it for now.

ZYHowell commented 1 year ago

It's mainly sequential, but there will be some skip connections (e.g., stage 0 -> stage 2, stage 0 -> stage 3, etc.); otherwise we wouldn't have this issue. Besides, there are both forward and backward stages, so stage 0 and stage -1 are on the same mesh.
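Assuming the common layout where `n` meshes each host one forward stage and its matching backward stage (stage `k` and stage `2n - 1 - k`, which is what "stage 0 and stage -1 are on the same mesh" suggests), a stage index maps to a mesh as below. The `mesh_of` helper is hypothetical and only illustrates the indexing.

```python
def mesh_of(stage, num_meshes):
    """Map a stage index in [0, 2 * num_meshes) to its mesh.

    Assumes forward stage k and backward stage 2 * num_meshes - 1 - k
    share mesh k.
    """
    if not 0 <= stage < 2 * num_meshes:
        raise ValueError(f"stage {stage} out of range")
    return stage if stage < num_meshes else 2 * num_meshes - 1 - stage


print([mesh_of(s, 4) for s in range(8)])  # [0, 1, 2, 3, 3, 2, 1, 0]
```

Under this assumption, a skip connection such as stage 0 -> stage 7 stays on mesh 0 and needs no cross-mesh transfer, while stage 1 -> stage 5 does cross meshes (mesh 1 to mesh 2).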

jon-chuang commented 1 year ago

each Var has its own id

I presume this is unique, so I will use it as var uuid to act as a lookup key.

ZYHowell commented 1 year ago

You can just use the Var itself; it wraps such an id.

jon-chuang commented 1 year ago

Another question:

Given var x on stage 0 and consumed by stage 1, but not output by stage 1, do we need to now add var x to the outvars of stage 1 to be consumed from stage 1 by a downstream stage 2?

Further, is there any cost to adding all invars to outvars of every stage by default (except messiness)?

ZYHowell commented 1 year ago

It depends on your algorithm. I think the first principle is not to increase the total communication size. E.g., if we originally send 0>2, I can't see any advantage in making it 0>1>2. The case in this issue is that 0>1>2 is better than (0>1 and 0>2). In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can instead have 2 send x to 1 and 1 send x + y to 0.
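The "don't increase total communication" principle, combined with the goal of relieving the busiest sender, can be written as an acceptance check for a candidate rewrite. This sketch assumes each transfer is `(src_mesh, dst_mesh, num_bytes)` and measures load as bytes sent per mesh; `plan_cost` and `accept` are hypothetical helpers, not Alpa code.

```python
from collections import Counter


def plan_cost(transfers):
    """Return (total bytes moved, bytes sent by the busiest mesh)."""
    sent = Counter()
    for src, _dst, nbytes in transfers:
        sent[src] += nbytes
    return sum(sent.values()), max(sent.values(), default=0)


def accept(old, new):
    """Accept a rewrite only if total traffic does not grow and the
    busiest sender becomes strictly less loaded."""
    old_total, old_peak = plan_cost(old)
    new_total, new_peak = plan_cost(new)
    return new_total <= old_total and new_peak < old_peak


x = 1 << 20  # size of tensor x in bytes (arbitrary)
# 0>2 rewritten as 0>1>2: doubles total traffic, so it is rejected.
print(accept([(0, 2, x)], [(0, 1, x), (1, 2, x)]))             # False
# (0>1 and 0>2) rewritten as 0>1>2: same total, mesh 0 sends only once.
print(accept([(0, 1, x), (0, 2, x)], [(0, 1, x), (1, 2, x)]))  # True
```

The `x + y` case is about mesh 0 receiving one tensor instead of two, i.e. receive-side load, which this sender-only metric does not capture; a fuller cost model would track both directions.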

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to the corresponding outvars. Besides, the messiness itself might affect later passes, so we'd like to avoid it.

jon-chuang commented 1 year ago

The algorithm is a simple one: take the var from the last stage that has seen it:

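def plan_sources(stages, cross_deps):
    last_seen = {}      # var -> id of the latest stage that received it
    recv_plan = {}      # (stage id, var) -> id of the stage to receive from
    extra_outvars = {}  # stage id -> vars it must additionally output
    # Sequentially walk the stages
    for stage in stages:
        for src, var in cross_deps[stage.id]:
            # If var is a dep, check whether an earlier stage already received it.
            # If so, add it to that stage's outvars and fetch from the latest such stage.
            if var in last_seen:
                src = last_seen[var]
                extra_outvars.setdefault(src, set()).add(var)
            recv_plan[(stage.id, var)] = src
            last_seen[var] = stage.id
    return recv_plan, extra_outvars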

Is adding to outvars necessary? It seems that in our case we don't need to add to outvars; can't we fetch from the invars instead?

invars -> [model] -> outvars ==cross-shard==>
 |======================cross-shard=>

This would mean that the cross-shard invars can begin to be re-sharded prior to the model invocation.

However, I'm not sure whether outvars are merely logical, i.e., whether we can start the same async transfer as soon as one of the outvars is ready, as marked by the runtime.

Adding invars to outvars makes them live longer, and some invars are designed to donate their memory to corresponding outvars.

I will avoid this then.

ZYHowell commented 1 year ago

The heuristic works for the first scenario. In the (0>1 & 0>2) case above, we don't need to add it to 1's outvars. You can read PipelineInstEmitter for more details on how we actually launch send/recv and free tensors. We've already implemented the async transfer with CUDA events, plus an injected kernel that records the event.
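A CPU-side analogy of that overlap, using Python's `threading.Event` in place of a CUDA event: the sender blocks only until the specific tensor is ready, not until the whole stage finishes. This illustrates the idea only and is an assumption about its shape; Alpa's real mechanism works with CUDA streams and events, as described in the pull requests referenced later in this thread.

```python
import threading

log = []                     # order in which things happen
x_ready = threading.Event()  # stands in for the CUDA event recorded after x


def run_stage():
    log.append("compute x")
    x_ready.set()            # "record" the event right after x is produced
    log.append("compute rest of stage")


def send_x():
    x_ready.wait()           # wait for x only, not for the whole stage
    log.append("send x")


sender = threading.Thread(target=send_x)
sender.start()
run_stage()
sender.join()
print(log)
```

"send x" always comes after "compute x", but it may run before "compute rest of stage", which is the overlap the event makes possible.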

jon-chuang commented 1 year ago

Btw, as far as the producer goes, does every var correspond to, e.g., a single tensor sharded across the entire submesh?

Anyway, adding an invar to the outvars is non-trivial: one has to deal with donation, and might also need to recompile the jaxpr. I'd prefer that the transfer take a different pathway rather than piggybacking on outvars. Any suggestions?

jon-chuang commented 1 year ago

In addition, if 2 sends x to 0 and 1 sends y to 0, but 0 only uses x + y, we can make it be 2 sends x to 1 and 1 sends x+y to 0.

This seems more complicated. To an SQL database person, it sounds like expression pushdown.

It sounds like we really do want to reconfigure the jaxpr after pipelining but before sharding. So in the pipelining pass, we should use our last-seen-stage heuristic to force the relevant invars to become outvars. I'm not sure whether we should skip an invar if it is donated before this pass.

We've already done the async transfer with cuda event and some kernel injected to record the event.

However, I don't understand how the async queueing would occur in this case. Will every jaxpr outvar be evaluated and queued async concurrently?

ZYHowell commented 1 year ago

A var corresponds to a logical tensor including all its shards.

In the "pipeline pass", we only decide how the computational graph is divided into pipeline stages, but not the communication between pipeline stages.

Instead, in PipelineInstEmitter, we create a schedule for each device mesh, where we manage the execution of every communication, computation, and memory deallocation (tensors are allocated only by computation, and communication reuses those allocated tensors as receive buffers). There, we store the live tensors of each mesh in PipelineInstEmitterHelper at each time tick.
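A minimal sketch of that bookkeeping, assuming the emitter can ask which meshes hold a given var at a given tick. `LiveTensorTable` is a hypothetical stand-in, not the real PipelineInstEmitterHelper interface.

```python
from collections import defaultdict


class LiveTensorTable:
    """Hypothetical record of which vars are live on each mesh at each tick."""

    def __init__(self):
        # tick -> mesh id -> set of live vars
        self._live = defaultdict(lambda: defaultdict(set))

    def add(self, tick, mesh, var):
        self._live[tick][mesh].add(var)

    def holders(self, tick, var):
        """Meshes holding `var` at `tick`, sorted for determinism."""
        meshes = self._live.get(tick, {})
        return sorted(m for m, live in meshes.items() if var in live)


table = LiveTensorTable()
table.add(0, 0, "x")  # x produced on mesh 0
table.add(1, 0, "x")  # x still live on mesh 0
table.add(1, 1, "x")  # x received by mesh 1
print(table.holders(1, "x"))  # [0, 1]
```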

For the forward part, I think we only need to modify code in PipelineInstEmitter so that it emits send/recv from the mesh with the least traffic among all meshes holding the tensor. The backward part is more complicated; it might also involve the "pipeline pass".
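A greedy sketch of that rule, assuming a mesh's traffic is just the bytes it has already been scheduled to send and that every receiver keeps the tensor live for later consumers. `pick_source` is hypothetical; in the real emitter the set of holders would presumably come from the live-tensor bookkeeping in PipelineInstEmitterHelper.

```python
def pick_source(holders, traffic, nbytes):
    """Choose the holder mesh with the least bytes already scheduled to send.

    `holders`: mesh ids that hold the tensor; `traffic`: mesh id -> bytes
    already scheduled. Ties go to the lower mesh id. Updates `traffic`.
    """
    src = min(holders, key=lambda m: (traffic.get(m, 0), m))
    traffic[src] = traffic.get(src, 0) + nbytes
    return src


traffic = {}
holders = [0]          # stage 0's mesh produced the tensor
plan = []
for dst in (1, 2, 3):  # consumers, in schedule order
    src = pick_source(holders, traffic, 4 << 20)  # a 4 MiB tensor
    plan.append((src, dst))
    holders.append(dst)  # assume the receiver keeps the tensor live
print(plan)  # [(0, 1), (1, 2), (2, 3)]
```

With one producer and three consumers, the greedy choice reproduces the chain from the issue description. It ignores receive-side load and timing; a real policy would also have to check that the tensor is still live on the chosen mesh at that tick.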

I'd suggest you read the following for more details: the architecture section in the project's docs and https://github.com/alpa-projects/alpa/blob/1ddb2dc30575b218a38937326682e020881dbc8e/alpa/pipeline_parallel/runtime_emitter.py#L545-L591.

For overlapping communication and computation, please refer to https://github.com/alpa-projects/alpa/pull/773 and https://github.com/alpa-projects/tensorflow-alpa/pull/127