rdspring1 opened this issue 1 month ago
> Current Multi-GPU support in NvFuser.
Yes, that's the right summary.
> Why expose communication expressions in python-frontend?
Is that a question for me? We don't expose communication expressions in the python frontend. For example, https://github.com/NVIDIA/Fuser/blob/4c9090e2a2c57fcffa9b928b8995e426d345d134/tests/python/test_multidevice.py#L43-L49 is built without set/reduce. Users only need to give some (but not all) tensors a mesh and parallel types, and sharding propagation takes care of the rest.
> This is the current implementation.
>
> ```python
> class FusionDefinition:
>     def definition(self):
>         A = fd.from_pytorch(ptA)
>         B = fd.from_pytorch(ptB)
>         C = fd.ops.matmul(A, B)
>         fd.add_output(C)
>
>     def multigpu_schedule(self):
>         shard_A = self.schedule.cacheAfter(A)
>         shard_B = self.schedule.cacheAfter(B)
>         shard_C = self.schedule.cacheBefore(C)
>         # create device mesh
>         # apply ParallelType.mesh_x, ParallelType.mesh_y, and ParallelType.mesh_z to create layout
>         # pre-segmentation propagates device mesh and layout through fusion.
> ```
>
> The sharding propagation rules would be in a single pass rather than per-expression in the DTensor API.
No, it's not the current implementation. One of the preseg passes (insert_resharding) inserts a `set` around a resharding expression to separate resharding from the math. Therefore, we don't need cacheBefore and cacheAfter.
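Roughly, the pass turns something like the first form into the second (pseudocode only, not actual nvFuser IR):

```
# Before the pass: the add both computes and reshards, because its output
# lives on a different mesh/layout than its inputs.
T2[DIDx(M), N] = add(T0[M, N], T1[M, N])     # math + communication mixed

# After the pass: an inserted set isolates the resharding, so the math stays
# local to each device and only the set expression communicates.
T3[M, N]       = add(T0[M, N], T1[M, N])     # pure local math
T2[DIDx(M), N] = set(T3[M, N])               # pure resharding
```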
> DTensor:
>
> ```python
> class FusionDefinition:
>     def definition(self):
>         A = fd.from_pytorch(ptA)
>         B = fd.from_pytorch(ptB)
>
>         # Create mesh
>         mesh = fd.create_mesh(...)
>
>         # Create layout for DTensor
>         rowwise_placement = [Shard(0)]
>         colwise_placement = [Shard(1)]
>
>         # Create DTensor input tensors from local tensors.
>         dtA = create_dtensor(A, mesh, rowwise_placement)
>         dtB = create_dtensor(B, mesh, colwise_placement)
>
>         # Apply matmul on DTensor.
>         # dtC gets layout from sharding propagation rules.
>         dtC = fd.ops.matmul(dtA, dtB)
>
>         # Get local tensor from output dtensor.
>         C = dtC.to_local()
>         fd.add_output(C)
> ```
This is ballpark what I had in mind. Speaking of the implementation, I'm unsure about creating a separate `DistributedTensor` class in addition to the existing `Tensor` class in our Python API. AFAICT this would require us to reimplement the `fd.ops` APIs for `DistributedTensor`s, which is a lot of work.
To get started, I'm inclined to have something like GSPMD's mesh_split API, which merely annotates the sharding and embeds the annotation inside the definition. This should be enough given we expect nvFuser's sharding propagation to do most of the heavy lifting.
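For example, something along these lines; this is a hypothetical sketch, and `fd.mesh_split` with these arguments is not an existing nvFuser API:

```python
# Hypothetical sketch only: a GSPMD-style mesh_split annotation embedded in the
# definition. `fd.mesh_split` and its arguments are made up for illustration.
# The user only annotates input shardings; sharding propagation infers the
# layouts of everything else, and no communication ops appear in the definition.
class FusionDefinition:
    def definition(self):
        A = fd.from_pytorch(ptA)
        B = fd.from_pytorch(ptB)
        C = fd.ops.matmul(A, B)
        # Annotate: shard A row-wise and B column-wise over a 1-D mesh of 4
        # devices; C's layout is left to sharding propagation.
        fd.mesh_split(A, mesh=[0, 1, 2, 3], axis=0)
        fd.mesh_split(B, mesh=[0, 1, 2, 3], axis=1)
        fd.add_output(C)
```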
> Why expose communication expressions in python-frontend?
This was a question for myself. Communication expressions blur the line between the math definition and fusion scheduling. I think we should support them in the definition to represent any PyTorch script.
> I think we should support them in the definition to represent any PyTorch script.
I'm not sure I follow. We are very incentivized to hide them from the user and thus the definition, because people tend to make mistakes or communicate in a suboptimal way. We might want to expose communication for debugging, but I'd prefer exposing that via host IR.
The only direct users of the python-frontend are ourselves. We should prioritize our productivity.
Do you intend to expose HostIR in python?
If Thunder traces a python program with `torch.distributed` operations, how would NvFuser handle them?
What if I want to trace a SOTA implementation like Megatron-Core?
IIUC, our implementation comes from porting their approach to NvFuser. Won't there always be a lag time for supporting their latest research?
> 2. Exposing communication primitives in python frontend and manually scheduling them in fusion definition.
This is exactly what I've wanted from nvfuser. Currently, a Thunder trace of ddp/fsdp has all the distributed communication ops in it, namely all-reduce, all-gather, and reduce-scatter, but nvfuser does not have a python API for communication ops, so the ops are kind of graph break points.
> If Thunder traces a python program with `torch.distributed` operations, how would NvFuser handle them?

> nvfuser does not have a python API for communication ops, so the ops are kind of graph break points.
Both of you are asking good questions about how nvFuser and Thunder play together. It's a large design space that I haven't explored fully, and I'm sure @kevinstephano has better ideas.
I plan to let nvFuser solve the following two problems:

- In the short term, take care of only tensor parallelism (including sequence parallelism, context parallelism, etc.) and offload data parallelism and/or pipeline parallelism to Thunder. This is because (1) Thunder already supports DDP and FSDP, and (2) TP has a larger search space, and nvFuser's existing single-GPU schedulers and optimizations can be reused to solve TP.

So to your original questions, I don't plan to let nvFuser take a tensor-parallel Thunder trace instrumented with torch.distributed operations. It's certainly doable but isn't the best investment at this moment. Instead, I think the most immediate goal is to allow nvFuser to take a data-parallel Thunder trace with tensor-parallel annotations. There, DP is implemented using torch.distributed ops (or DTensor?), and the TP intention is represented using annotations that nvFuser can process further. Will torch.distributed become graph breaks? Yes, but they won't be everywhere dictating all communications, so hopefully nvFuser will still have quite a few good regions in which to optimize TP.
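For illustration, here is a hand-written toy of what I mean; it is not Thunder's actual trace format, and the `tp_annotations` dict is made up, though `torch.distributed.all_reduce` is a real collective:

```python
# Toy stand-in for the kind of trace meant here; NOT Thunder's actual format.
# Data parallelism shows up as explicit torch.distributed ops (graph breaks
# for nvFuser), while tensor parallelism is only an annotation that nvFuser
# could consume for the math regions. Assumes torch.distributed has been
# initialized, e.g. via torchrun.
import torch
import torch.distributed as dist

tp_annotations = {            # made-up TP annotations attached to the inputs
    "x": "replicated",
    "w1": "column-wise sharded",
    "w2": "row-wise sharded",
}

def mlp_fwd_bwd(x, w1, w2, grad_out):
    # Math region: a candidate for nvFuser to optimize under the TP annotations.
    h = torch.relu(x @ w1)
    out = h @ w2
    grad_w2 = h.t() @ grad_out             # backward math heavily abbreviated
    # DP region: explicit collective inserted by DDP-style data parallelism.
    dist.all_reduce(grad_w2, op=dist.ReduceOp.SUM)
    return out, grad_w2
```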
> the most immediate goal is to allow nvFuser to take a data-parallel Thunder trace with tensor-parallel annotations.

Could you give me a toy example trace of tensor parallelism and data parallelism that satisfies the conditions?
The closest I can find is https://gist.github.com/wujingyue/b111aa8b8d92067fc6004f5d0488dd27, the forward and backward trace for a transformer layer. You can imagine the same traces with inputs being annotated row-wise sharded, column-wise sharded, or replicated.
How would a trace look if we only use tensor parallelism?
The same traces above but with batch size > 1.
The inputs need to be annotated with the involved data parallel schemes, and the trace also needs to have `torch.distributed.ProcessGroup`s or `DeviceMesh`?
Yes, I believe the trace needs to contain the needed DP constructs because we are talking about combining DP and TP. I'm just unsure about the exact format. Would you mind sending me a DDP'ed trace and/or teaching me how I can generate one? This'll help me think more concretely.
> how I can generate one?
I'd use https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/benchmarks/benchmark_litgpt.py and run it with `torchrun --nproc-per-node <NGPUS> benchmark_litgpt.py --distributed_mode fsdp/ddp --compile thunder --dump_thunder_traces true`. For ddp, set `--ddp_bucket_size=0` if gradient bucketing for all-reduce isn't desirable. `--dump_thunder_traces true` would print the last trace (the one to execute).
Just so you know, `TensorProxy` arguments would have the attribute `distparallel_type` indicating DDP, FSDP, or N/A.
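For example, roughly (a sketch; `model` and `example_input` are placeholders, and exact attribute locations may differ across Thunder versions):

```python
# Sketch, assuming the Thunder pieces mentioned in this thread: grab the last
# (executed) trace and check the distparallel_type attribute on its
# TensorProxy arguments.
import thunder

jitted = thunder.jit(model)            # `model` wrapped with thunder's ddp/fsdp
jitted(example_input)                  # run once so traces exist
execution_trace = thunder.last_traces(jitted)[-1]
for proxy in execution_trace.args:
    print(proxy.name, getattr(proxy, "distparallel_type", None))
```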
🚀 The feature, motivation and pitch
RFC: Multi-Gpu Python Frontend API
This RFC compares and contrasts some ideas for exposing multi-gpu support in the python frontend, including the `multigpu_schedule` approach.

Current Multi-GPU support in NvFuser.
- `multidevice_schedule` function to create device mesh and to apply ParallelType layout.

DTensor in NvFuser.
- References:
- API Example:
- Description:
  - In PyTorch, a `DTensor` is a tensor with a device mesh and a layout.
  - In NvFuser, a `DTensor` is a TensorView with a device mesh. The layout is specified by setting `ParallelType::DIDx`, `ParallelType::DIDy`, and `ParallelType::DIDz` on some IterDomains.
  - Apply propagation rules through operations for `DTensor`s in FusionDefinition.
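For reference, a minimal sketch of the PyTorch `DTensor` API described above (module paths are version-dependent):

```python
# Minimal PyTorch DTensor example (run under torchrun with 2 ranks and 2 GPUs).
# Recent PyTorch exposes these under torch.distributed.tensor; older versions
# use torch.distributed._tensor.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))                       # 1-D device mesh
full = torch.randn(8, 4)
dt = distribute_tensor(full, mesh, placements=[Shard(0)])   # row-wise sharded
print(dt.to_local().shape)                                  # each rank holds a (4, 4) shard
```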
Manual Multi-Gpu Definition.
- Why expose communication expressions in python-frontend?
What is a multi-gpu matmul?
- Goal: Compute `C[M, N] = A[M, K] @ B[K, N]` using a mesh of devices.
- Shard the A and B input matrices according to the C output matrix.
  - `sA` is row-wise sharded.
  - `sB` is col-wise sharded.
- Apply matmul given the A and B shards: `sC[sM, sN] = sA[sM, K] @ sB[K, sN]`
- Gather the C output shards to get the full C output matrix.
  - `sC` is gathered from all devices to create the C matrix.
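For illustration, a single-process sketch that emulates this decomposition with plain torch ops; there is no real communication here, and the gather is just block assembly:

```python
# Single-process emulation of the sharded matmul above: shard A row-wise and
# B column-wise across a 2x2 "mesh", compute each C block locally, then
# assemble ("gather") the blocks into the full C.
import torch

M, K, N, mesh_rows, mesh_cols = 8, 16, 8, 2, 2
A, B = torch.randn(M, K), torch.randn(K, N)

A_shards = A.chunk(mesh_rows, dim=0)   # sA: row-wise shards of A
B_shards = B.chunk(mesh_cols, dim=1)   # sB: column-wise shards of B

# Device (i, j) computes sC[sM, sN] = sA[sM, K] @ sB[K, sN].
C_blocks = [[A_shards[i] @ B_shards[j] for j in range(mesh_cols)]
            for i in range(mesh_rows)]

# Gather the output shards into the full C and check against the reference.
C = torch.cat([torch.cat(row, dim=1) for row in C_blocks], dim=0)
assert torch.allclose(C, A @ B, atol=1e-5)
```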
Multidevice Schedule:
Manual Scheduling with Multidevice Schedule:
Manual Scheduling in definition:
DTensor:
cc @wujingyue @kevinstephano