NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

[RFC] Multi-Gpu Python Frontend API #3094

Open rdspring1 opened 1 month ago

rdspring1 commented 1 month ago

🚀 The feature, motivation and pitch

RFC: Multi-Gpu Python Frontend API

This RFC compares and contrasts several ideas for exposing multi-GPU support in the Python frontend:

  1. The current multidevice_schedule approach.
  2. Exposing communication primitives in the Python frontend and manually scheduling them in the fusion definition.
  3. The PyTorch DTensor API.

Current Multi-GPU support in NvFuser.

  1. In the Python frontend, define a multidevice_schedule function that creates a device mesh and applies a ParallelType layout.
  2. During pre-segmentation, propagate the DeviceMesh and ParallelType through the fusion.
  3. During segmentation, split the fusion into compute and communication segments.
  4. During compileFusion, translate Reduction and LoadStoreOp expressions that carry a DeviceMesh into communication expressions, then create a HostIRContainer for those communication expressions.
  5. During runFusion, run the CUDA kernels or the communication primitives.
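Steps 2 and 3 can be pictured with a toy model in plain Python. Everything here (the Tensor and Expr classes, propagate, is_resharding, segment) is hypothetical and much simpler than nvFuser's actual passes; in particular, propagation just copies the first input's mesh and sharding, which assumes axis-preserving ops such as pointwise ones.

```python
from dataclasses import dataclass
from typing import Optional


# Toy model only: these classes are hypothetical and do not mirror nvFuser's C++ IR.
@dataclass
class Tensor:
    name: str
    mesh: Optional[tuple] = None      # device ids, e.g. (0, 1)
    sharding: Optional[tuple] = None  # one parallel type (or None) per axis, e.g. ("DIDx", None)


@dataclass
class Expr:
    op: str
    inputs: list
    output: Tensor


def propagate(exprs):
    """Step 2 (toy): unannotated outputs inherit the mesh and sharding of their first input.

    Only valid for axis-preserving (e.g. pointwise) ops.
    """
    for e in exprs:
        src = e.inputs[0]
        if e.output.mesh is None:
            e.output.mesh = src.mesh
        if e.output.sharding is None:
            e.output.sharding = src.sharding


def is_resharding(e):
    """An expr reshards if any input's mesh or sharding differs from its output's."""
    return any(t.mesh != e.output.mesh or t.sharding != e.output.sharding for t in e.inputs)


def segment(exprs):
    """Step 3 (toy): label each expr as belonging to a compute or a communication segment."""
    return [("communication" if is_resharding(e) else "compute", e.op) for e in exprs]


# Usage: x is sharded on axis 0; z is requested fully replicated on the same mesh.
x = Tensor("x", (0, 1), ("DIDx", None))
y = Tensor("y")
z = Tensor("z", (0, 1), (None, None))
exprs = [Expr("relu", [x], y), Expr("set", [y], z)]
propagate(exprs)
print(segment(exprs))  # [('compute', 'relu'), ('communication', 'set')]
```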

DTensor in NvFuser.

References:

API Example:

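import os
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor

# Launch with torchrun, one process per GPU.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
device_mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))
placements = [Shard(0)]  # the "layout": dim 0 is sharded across the 1-D mesh
x = torch.randn(4, 8, device="cuda")  # this rank's local shard
distx = DTensor.from_local(x, device_mesh, placements)  # global shape: (4 * WORLD_SIZE, 8)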

Description:

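In PyTorch, a DTensor is a tensor with a device mesh and a layout; the layout is a list of placements, such as Shard(dim) or Replicate(), with one entry per mesh dimension.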

In NvFuser, a DTensor is a TensorView with a device mesh. The layout is specified by setting ParallelType::DIDx, ParallelType::DIDy, and ParallelType::DIDz on the IterDomains that are sharded across the mesh.
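One way to picture this mapping, assuming the i-th mesh dimension of a DTensor corresponds to DIDx, DIDy, or DIDz in that order: Shard(d) on mesh dimension i parallelizes tensor axis d with the i-th DID type, and Replicate() adds no annotation. The sketch is plain Python with local stand-ins for Shard and Replicate; placements_to_parallel_types is a hypothetical helper, not an nvFuser API, and it deliberately rejects partial placements and a tensor axis sharded over several mesh dimensions.

```python
from dataclasses import dataclass


# Minimal stand-ins for torch.distributed's Shard and Replicate placements.
@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Replicate:
    pass


# Assumption: mesh dimension i maps to the i-th device parallel type.
MESH_AXIS_TO_PTYPE = ["DIDx", "DIDy", "DIDz"]


def placements_to_parallel_types(placements, ndim):
    """Return one parallel type (or None) per tensor axis for the given placements."""
    if len(placements) > len(MESH_AXIS_TO_PTYPE):
        raise ValueError("at most three mesh dimensions (DIDx, DIDy, DIDz)")
    ptypes = [None] * ndim
    for mesh_axis, p in enumerate(placements):
        if isinstance(p, Replicate):
            continue  # replicated along this mesh axis: no annotation needed
        if isinstance(p, Shard):
            if ptypes[p.dim] is not None:
                # This toy does not split one tensor axis across several mesh axes.
                raise ValueError(f"tensor dim {p.dim} is already sharded")
            ptypes[p.dim] = MESH_AXIS_TO_PTYPE[mesh_axis]
        else:
            raise NotImplementedError(f"unsupported placement: {p!r}")
    return ptypes


print(placements_to_parallel_types([Shard(0)], ndim=2))               # ['DIDx', None]
print(placements_to_parallel_types([Replicate(), Shard(1)], ndim=2))  # [None, 'DIDy']
```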

Apply sharding propagation rules through each operation on DTensors in the FusionDefinition.
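As an illustration of such a rule, here is a toy propagation rule for matmul, evaluated independently on each mesh axis. The cases follow the standard strategies (split rows of A, split columns of B, or split the contracted dimension K, which leaves partial sums); the Partial stand-in and the convention of returning None when an operand must be redistributed are assumptions of this sketch, not DTensor's or nvFuser's implementation.

```python
from dataclasses import dataclass


# Minimal stand-ins for DTensor-style placements.
@dataclass(frozen=True)
class Shard:
    dim: int


@dataclass(frozen=True)
class Replicate:
    pass


@dataclass(frozen=True)
class Partial:
    """Each device holds a partial sum; a reduction across the mesh axis is pending."""


def matmul_rule(a, b):
    """Placement of C = A @ B along one mesh axis, with A: [M, K] and B: [K, N].

    Returns None when no local rule applies and an operand must be redistributed first.
    """
    if a == Shard(0) and b == Replicate():
        return Shard(0)    # rows of A split -> rows of C split
    if a == Replicate() and b == Shard(1):
        return Shard(1)    # columns of B split -> columns of C split
    if a == Shard(1) and b == Shard(0):
        return Partial()   # K split -> local products must be summed (e.g. all-reduce)
    if a == Replicate() and b == Replicate():
        return Replicate()
    return None


def propagate_matmul(a_placements, b_placements):
    """Apply the rule independently on every mesh axis."""
    return [matmul_rule(a, b) for a, b in zip(a_placements, b_placements)]


# 2-D mesh: A row-sharded on mesh axis 0, B column-sharded on mesh axis 1.
print(propagate_matmul([Shard(0), Replicate()], [Replicate(), Shard(1)]))
# [Shard(dim=0), Shard(dim=1)]
print(propagate_matmul([Shard(0)], [Shard(1)]))  # [None]
```

The last case, [Shard(0)] with [Shard(1)] on a 1-D mesh, has no local rule in this toy, so a DTensor-style system would have to redistribute one operand before the matmul.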


Manual Multi-Gpu Definition.

Why expose communication expressions in python-frontend?


What is a multi-gpu matmul?

Goal: Compute C[M, N] = A[M, K] @ B[K, N] using a mesh of devices.

  1. Shard the input matrices A and B according to the output matrix C: sA is sharded row-wise and sB is sharded column-wise.

  2. Apply the matmul to the A and B shards: sC[sM, sN] = sA[sM, K] @ sB[K, sN]

  3. Gather the output shards sC from all devices to form the full output matrix C.
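These three steps can be checked with a small CPU simulation in plain Python (hypothetical helper names, no real devices). It assumes a 2-D mesh of p × q devices in which device (i, j) holds row block i of A and column block j of B; with a 1-D mesh where device i holds only sA_i and sB_i, the local products would cover only the diagonal blocks of C. It also assumes M and N divide evenly by p and q.

```python
def matmul(a, b):
    """Naive dense matmul on lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def split_rows(m, parts):
    step = len(m) // parts
    return [m[i * step:(i + 1) * step] for i in range(parts)]


def split_cols(m, parts):
    step = len(m[0]) // parts
    return [[row[j * step:(j + 1) * step] for row in m] for j in range(parts)]


def sharded_matmul(a, b, p, q):
    # 1. Shard: device (i, j) gets row block i of A and column block j of B.
    a_shards = split_rows(a, p)
    b_shards = split_cols(b, q)
    # 2. Local matmul on every device: sC[i][j] = sA[i] @ sB[j], no communication.
    c_shards = {(i, j): matmul(a_shards[i], b_shards[j]) for i in range(p) for j in range(q)}
    # 3. Gather: stitch the p x q output blocks back into the full C.
    c = []
    for i in range(p):
        for r in range(len(c_shards[(i, 0)])):
            c.append([x for j in range(q) for x in c_shards[(i, j)][r]])
    return c


A = [[1, 2], [3, 4], [5, 6], [7, 8]]  # M=4, K=2
B = [[1, 0, 2, 1], [0, 1, 1, 3]]      # K=2, N=4
print(sharded_matmul(A, B, p=2, q=2) == matmul(A, B))  # True
```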


Multidevice Schedule:

class FusionDefinition:
    def definition(self):
        # ptA and ptB are the input torch.Tensors.
        A = self.from_pytorch(ptA)
        B = self.from_pytorch(ptB)
        C = self.ops.matmul(A, B)
        self.add_output(C)

    def multidevice_schedule(self):
        # create device mesh.
        # apply ParallelType.mesh_x, ParallelType.mesh_y, and ParallelType.mesh_z to create layout.
        ...

Manual Scheduling with Multidevice Schedule:

class FusionDefinition:
    def definition(self):
        # Keep handles on self so that multidevice_schedule can reach them.
        self.A = self.from_pytorch(ptA)
        self.B = self.from_pytorch(ptB)
        self.C = self.ops.matmul(self.A, self.B)
        self.add_output(self.C)

    def multidevice_schedule(self):
        # Proposed: scatter/gather variants of LoadStoreOpType.
        shard_A = self.schedule.cache_after(self.A, LoadStoreOpType.scatter)
        shard_B = self.schedule.cache_after(self.B, LoadStoreOpType.scatter)
        shard_C = self.schedule.cache_before(self.C, LoadStoreOpType.gather)

        # manually apply device mesh and layout to fusion.
        for t in self.all_tensors():
            # create device mesh
            # apply ParallelType.mesh_x, ParallelType.mesh_y, and ParallelType.mesh_z to create layout.
            ...

Manual Scheduling in definition:

class FusionDefinition:
    def definition(self):
        A = self.from_pytorch(ptA)
        B = self.from_pytorch(ptB)

        # Create mesh
        mesh = self.create_mesh(...)

        # Create layout for DTensor
        rowwise_placement = [Shard(0)]
        colwise_placement = [Shard(1)]

        # Scatter to shard input tensors across devices (proposed primitive).
        shards_A = scatter(A, mesh, rowwise_placement)
        shards_B = scatter(B, mesh, colwise_placement)

        # Apply matmul across shards
        shards_C = [self.ops.matmul(shard_A, shard_B) for shard_A, shard_B in zip(shards_A, shards_B)]

        # Gather sharded tensors to get output tensor (proposed primitive).
        C = gather(shards_C, mesh)
        self.add_output(C)

DTensor:

class FusionDefinition:
    def definition(self):
        A = self.from_pytorch(ptA)
        B = self.from_pytorch(ptB)

        # Create mesh
        mesh = self.create_mesh(...)

        # Create layout for DTensor
        rowwise_placement = [Shard(0)]
        colwise_placement = [Shard(1)]

        # Create DTensor input tensors from local tensors.
        dtA = self.ops.create_dtensor(A, mesh, rowwise_placement)
        dtB = self.ops.create_dtensor(B, mesh, colwise_placement)

        # Apply matmul on DTensor.
        # dtC gets layout from sharding propagation rules.
        dtC = self.ops.matmul(dtA, dtB)

        # Get local tensor from output dtensor.
        C = self.ops.to_local(dtC)
        self.add_output(C)

cc @wujingyue @kevinstephano

wujingyue commented 1 month ago

Current Multi-GPU support in NvFuser.

Yes, that's the right summary.

wujingyue commented 1 month ago

Why expose communication expressions in python-frontend?

Is that a question for me? We don't expose communication expressions in the Python frontend. For example, https://github.com/NVIDIA/Fuser/blob/4c9090e2a2c57fcffa9b928b8995e426d345d134/tests/python/test_multidevice.py#L43-L49 is built without set/reduce. Users only need to give some (but not all) tensors a mesh and parallel types, and sharding propagation takes care of the rest.

wujingyue commented 1 month ago

This is the current implementation.

class FusionDefinition:
    def definition(self):
        self.A = self.from_pytorch(ptA)
        self.B = self.from_pytorch(ptB)
        self.C = self.ops.matmul(self.A, self.B)
        self.add_output(self.C)

    def multigpu_schedule(self):
        shard_A = self.schedule.cacheAfter(self.A)
        shard_B = self.schedule.cacheAfter(self.B)
        shard_C = self.schedule.cacheBefore(self.C)
        # create device mesh
        # apply ParallelType.mesh_x, ParallelType.mesh_y, and ParallelType.mesh_z to create layout
        # pre-segmentation propagates device mesh and layout through fusion. The sharding propagation rules would be in a single pass rather than per-expression in the DTensor API.

No, it's not the current implementation. One of the pre-segmentation passes (insert_resharding) inserts a set around each resharding expression to separate resharding from the math. Therefore, we don't need cacheBefore and cacheAfter.
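The idea of separating resharding from the math can be sketched as a toy rewrite in plain Python. This is not nvFuser's insert_resharding code: the classes and the function below are hypothetical, and the toy always inserts the set before the math expression, whereas a real pass may place it after instead (for example, reducing locally before communicating).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Tensor:
    name: str
    sharding: Optional[str]  # e.g. "DIDx" (sharded) or None (replicated)


@dataclass
class Expr:
    op: str
    inputs: list
    output: Tensor


def insert_resharding(exprs):
    """Toy pass: put a `set` in front of every resharding math expr.

    Afterwards each math expr sees inputs with its output's sharding, and all
    communication is isolated in the inserted `set` exprs.
    """
    result = []
    for e in exprs:
        new_inputs = []
        for t in e.inputs:
            if e.op != "set" and t.sharding != e.output.sharding:
                resharded = Tensor(t.name + "_resharded", e.output.sharding)
                result.append(Expr("set", [t], resharded))  # the communication
                new_inputs.append(resharded)
            else:
                new_inputs.append(t)
        result.append(Expr(e.op, new_inputs, e.output))
    return result


x = Tensor("x", "DIDx")
y = Tensor("y", None)
out = insert_resharding([Expr("relu", [x], y)])
print([(e.op, [t.name for t in e.inputs], e.output.name) for e in out])
# [('set', ['x'], 'x_resharded'), ('relu', ['x_resharded'], 'y')]
```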

wujingyue commented 1 month ago

DTensor:

class FusionDefinition:
    def definition(self):
        A = self.from_pytorch(ptA)
        B = self.from_pytorch(ptB)

        # Create mesh
        mesh = self.create_mesh(...)

        # Create layout for DTensor
        rowwise_placement = [Shard(0)]
        colwise_placement = [Shard(1)]

        # Create DTensor input tensors from local tensors.
        dtA = create_dtensor(A, mesh, rowwise_placement)
        dtB = create_dtensor(B, mesh, colwise_placement)

        # Apply matmul on DTensor.
        # dtC gets layout from sharding propagation rules.
        dtC = self.ops.matmul(dtA, dtB)

        # Get local tensor from output dtensor.
        C = dtC.to_local()
        self.add_output(C)

This is roughly what I had in mind. Regarding the implementation, I'm unsure about creating a separate DistributedTensor class in addition to the existing Tensor class in our Python API. AFAICT, this would require us to reimplement the fd.ops APIs for DistributedTensors, which is a lot of work.

To get started, I'm inclined to have something like GSPMD's mesh_split API, which merely annotates the sharding and embeds the annotation inside the definition. This should be enough given we expect nvFuser's sharding propagation to do most of the heavy lifting.
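GSPMD's mesh_split takes a tensor, a device mesh, and a dims_mapping that assigns tensor dimensions to mesh dimensions. The hypothetical helper below mimics only the annotation part, on shapes rather than tensors: it validates and records the sharding and derives each device's local shard shape, but inserts no communication and does not touch the math. It is a plain-Python sketch, not a proposed nvFuser signature.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShardingAnnotation:
    shape: tuple
    mesh_shape: tuple
    dims_mapping: tuple  # dims_mapping[d] = mesh axis splitting tensor dim d, or -1 if not split


def mesh_split(shape, mesh_shape, dims_mapping):
    """Hypothetical, annotation-only helper in the spirit of GSPMD's mesh_split.

    It records the intended sharding; it inserts no communication and leaves the math unchanged.
    """
    if len(dims_mapping) != len(shape):
        raise ValueError("dims_mapping needs one entry per tensor dim")
    used = [m for m in dims_mapping if m != -1]
    if len(used) != len(set(used)):
        raise ValueError("each mesh axis may split at most one tensor dim")
    for d, m in enumerate(dims_mapping):
        if m != -1 and shape[d] % mesh_shape[m] != 0:
            raise ValueError(f"tensor dim {d} is not divisible by mesh axis {m}")
    return ShardingAnnotation(tuple(shape), tuple(mesh_shape), tuple(dims_mapping))


def local_shape(ann):
    """Per-device shard shape implied by an annotation."""
    return tuple(s // ann.mesh_shape[m] if m != -1 else s
                 for s, m in zip(ann.shape, ann.dims_mapping))


w = mesh_split((8, 6), mesh_shape=(2, 3), dims_mapping=(0, 1))
print(local_shape(w))  # (4, 2)
x = mesh_split((8, 6), mesh_shape=(2, 3), dims_mapping=(-1, 1))
print(local_shape(x))  # (8, 2)
```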

rdspring1 commented 1 month ago

Why expose communication expressions in python-frontend?

This was a question for myself. Communication expressions blur the line between the math definition and fusion scheduling. I think we should support them in the definition to represent any pytorch script.

wujingyue commented 1 month ago

I think we should support them in the definition to represent any pytorch script.

I'm not sure I follow. We have a strong incentive to hide them from the user, and thus from the definition, because people tend to make mistakes or communicate suboptimally. We might want to expose communication for debugging, but I'd prefer exposing that via host IR.

rdspring1 commented 1 month ago

The only direct users of the python-frontend are ourselves. We should prioritize our productivity.

Do you intend to expose HostIR in python?

If Thunder traces a python program with torch.distributed operations, how would NvFuser handle them?

What if I want to trace a SOTA implementation like Megatron-Core?

IIUC, our implementation comes from porting their approach to NvFuser. Won't there always be a lag before we support their latest research?

crcrpar commented 1 month ago

2. Exposing communication primitives in python frontend and manually scheduling them in fusion definition.

This is exactly what I've wanted from nvFuser. Currently, a Thunder trace of DDP/FSDP contains all the distributed communication ops, namely all-reduce, all-gather, and reduce-scatter, but nvFuser does not have a Python API for communication ops, so the ops effectively become graph break points.
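For reference, the semantics of these three collectives can be simulated on plain Python lists, one buffer per rank. This is only a semantic sketch with hypothetical function names and no real communication; in practice DDP all-reduces gradients, while FSDP all-gathers parameters and reduce-scatters gradients.

```python
def all_reduce(per_rank):
    """Every rank ends up with the elementwise sum of all ranks' buffers."""
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]


def all_gather(per_rank):
    """Every rank ends up with the concatenation of all ranks' buffers, in rank order."""
    full = [x for buf in per_rank for x in buf]
    return [list(full) for _ in per_rank]


def reduce_scatter(per_rank):
    """Sum elementwise, then rank r keeps the r-th equal-sized chunk of the result."""
    world = len(per_rank)
    total = [sum(vals) for vals in zip(*per_rank)]
    chunk = len(total) // world
    return [total[r * chunk:(r + 1) * chunk] for r in range(world)]


bufs = [[1, 2, 3, 4], [10, 20, 30, 40]]  # two ranks
print(all_reduce(bufs))      # [[11, 22, 33, 44], [11, 22, 33, 44]]
print(reduce_scatter(bufs))  # [[11, 22], [33, 44]]
print(all_gather(reduce_scatter(bufs)) == all_reduce(bufs))  # True
```

The last line shows the usual identity: an all-reduce is equivalent to a reduce-scatter followed by an all-gather.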

wujingyue commented 1 month ago

If Thunder traces a python program with torch.distributed operations, how would NvFuser handle them?

nvfuser does not have python API for communication ops, the ops are kind of graph break points.

Both of you are asking good questions on how nvFuser and Thunder interoperate. It's a large design space that I haven't explored fully, and I'm sure @kevinstephano has better ideas.

I plan to let nvFuser solve the following two problems:

  1. In the short term, take care of only tensor parallelism (including sequence parallelism, context parallelism, etc.) and offload data parallelism and/or pipeline parallelism to Thunder. This is because (1) Thunder already supports DDP and FSDP, and (2) TP has a larger search space, and nvFuser's existing single-GPU schedulers and optimizations can be reused to solve TP. To make this happen, nvFuser needs a multi-GPU API that composes with Thunder's.
  2. In the long term, O(years), we should try to let nvFuser take care of all parallelism. This way, the Thunder trace will contain minimal annotations (e.g. on inputs and weights) and nvFuser will do the heavy lifting. This is high risk and high reward. It is cumbersome even for ML performance experts to hand-craft a parallelization strategy for every model architecture and its variations.

So, to answer your original questions, I don't plan to let nvFuser take a tensor-parallel Thunder trace instrumented with torch.distributed operations. It's certainly doable but isn't the best investment at this moment. Instead, I think the most immediate goal is to allow nvFuser to take a data-parallel Thunder trace with tensor-parallel annotations. There, DP is implemented using torch.distributed ops (or DTensor?), and the TP intention is represented using annotations that nvFuser can process further. Will torch.distributed ops become graph breaks? Yes, but they won't be everywhere dictating all communication, so hopefully nvFuser will still have plenty of good regions in which to optimize TP.
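Combining DP and TP typically means arranging ranks on a 2-D grid. The plain-Python sketch below (dp_tp_groups is a hypothetical helper) shows which ranks form the TP groups, within which the tensor-parallel annotations would apply, and which form the DP groups that the torch.distributed ops in the trace would communicate over. PyTorch's init_device_mesh("cuda", (dp, tp), mesh_dim_names=("dp", "tp")) lays out ranks in the same row-major order.

```python
def dp_tp_groups(world_size, tp_size):
    """Arrange ranks row-major on a (dp, tp) grid and return (dp_groups, tp_groups)."""
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    grid = [[d * tp_size + t for t in range(tp_size)] for d in range(dp_size)]
    tp_groups = grid                               # each row: one model replica, split by TP
    dp_groups = [list(col) for col in zip(*grid)]  # each column: ranks holding the same TP shard
    return dp_groups, tp_groups


dp_groups, tp_groups = dp_tp_groups(world_size=8, tp_size=4)
print(tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```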

crcrpar commented 2 weeks ago
  1. In the short term, take care of only tensor parallelism (including sequence parallelism, context parallelism, etc.) and offload data parallelism and/or pipeline parallelism to Thunder. This is because (1) Thunder already supports DDP and FSDP, and (2) TP has a larger search space, and nvFuser's existing single-GPU schedulers and optimizations can be reused to solve TP.

the most immediate goal is to allow nvFuser to take a data-parallel Thunder trace with tensor-parallel annotations.

wujingyue commented 2 weeks ago

Could you give me a toy example trace of tensor parallelism and data parallelism that satisfies the conditions?

The closest I can find is https://gist.github.com/wujingyue/b111aa8b8d92067fc6004f5d0488dd27, the forward and backward traces for a transformer layer. You can imagine the same traces with inputs annotated as row-wise sharded, column-wise sharded, or replicated.

What would a trace look like if we only use tensor parallelism?

The same traces as above, but with batch size > 1.

crcrpar commented 2 weeks ago

The inputs need to be annotated with the data-parallel schemes involved, and the trace also needs to contain torch.distributed ProcessGroups or a DeviceMesh?

wujingyue commented 2 weeks ago

Yes, I believe the trace needs to contain the necessary DP constructs because we are talking about combining DP and TP. I'm just unsure about the exact format. Would you mind sending me a DDP'ed trace and/or teaching me how I can generate one? This'll help me think more concretely.

crcrpar commented 2 weeks ago

how I can generate one?

I'd use https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/benchmarks/benchmark_litgpt.py and run it with torchrun --nproc-per-node <NGPUS> benchmark_litgpt.py --distributed_mode <fsdp|ddp> --compile thunder --dump_thunder_traces true

For DDP, set --ddp_bucket_size=0 if you don't want gradient bucketing for all-reduce.

--dump_thunder_traces true prints the last trace (the one that is executed).

Just so you know, TensorProxy arguments have a distparallel_type attribute indicating DDP, FSDP, or N/A.