NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Build Segments for User Schedule Segmentation #3334

Closed: rdspring1 closed this 1 week ago

rdspring1 commented 2 weeks ago

General Overview of Segmentation:

Segmentation decomposes a fusion into a directed acyclic graph (DAG) of sub-fusions. After the segmentation algorithm runs, each sub-fusion can be translated into its corresponding Python definition. Then, given the fusion's input arguments, the segments run in dependency order to produce the output results.
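Running segments in dependency order means running them in a topological order of the segment DAG. The sketch below is a generic Kahn's-algorithm ordering in plain Python. The function name `topological_order` and the `(producer, consumer)` edge format are illustrative assumptions; this is not nvFuser's prepareGroupOrder implementation.

```python
from collections import deque
from typing import Dict, List, Sequence, Tuple


def topological_order(num_segments: int,
                      edges: Sequence[Tuple[int, int]]) -> List[int]:
    """Kahn's algorithm. Each edge (producer, consumer) means the consumer
    segment reads a value produced by the producer segment."""
    indegree = [0] * num_segments
    consumers: Dict[int, List[int]] = {s: [] for s in range(num_segments)}
    for producer, consumer in edges:
        consumers[producer].append(consumer)
        indegree[consumer] += 1

    # Start with segments that depend only on fusion inputs.
    ready = deque(s for s in range(num_segments) if indegree[s] == 0)
    order: List[int] = []
    while ready:
        s = ready.popleft()
        order.append(s)
        for c in consumers[s]:
            indegree[c] -= 1
            if indegree[c] == 0:  # every producer of c is already scheduled
                ready.append(c)

    if len(order) != num_segments:
        raise ValueError("segment graph contains a cycle")
    return order
```

In the reduction + pointwise example in this description, the second segment reads the first segment's output, so `topological_order(2, [(0, 1)])` returns `[0, 1]`. A cycle cannot occur in a valid segmentation, so it is reported as an error rather than handled.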

The original FusionDefinition stores the sequence of sub-fusions and acts as an argument manager: it gathers the input arguments before running each sub-fusion and stores that sub-fusion's results. To do this, it needs a map from each segment's index space to the original fusion's index space. This mapping is generated while creating the Python definition for each sub-fusion.
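A minimal sketch of the argument-manager role, assuming each segment carries two hypothetical dictionaries (`input_map`, `output_map`) from segment-local indices to original state indices, and that the segments are already topologically ordered. `Segment` and `execute_segments` are illustrative names, not part of nvFuser's Python API, and plain Python values stand in for tensors.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence


@dataclass
class Segment:
    """Hypothetical stand-in for one sub-fusion plus its index mapping."""
    run: Callable[[List[Any]], List[Any]]
    input_map: Dict[int, int]   # segment-local input index -> original state index
    output_map: Dict[int, int]  # segment-local output index -> original state index


def execute_segments(segments: Sequence[Segment], args: Sequence[Any],
                     output_indices: Sequence[int]) -> List[Any]:
    # Original-fusion state keyed by original state index; fusion inputs come first.
    state: Dict[int, Any] = dict(enumerate(args))
    for seg in segments:  # assumed to be in topological order already
        # Gather this segment's inputs from the original index space.
        seg_args = [state[seg.input_map[i]] for i in range(len(seg.input_map))]
        # Store each result back under its original index for later segments.
        for local_idx, value in enumerate(seg.run(seg_args)):
            state[seg.output_map[local_idx]] = value
    return [state[i] for i in output_indices]


# Toy fusion: state 0 = a, state 1 = b, state 2 = a * 2, state 3 = state 2 + b.
double = Segment(run=lambda a: [a[0] * 2], input_map={0: 0}, output_map={0: 2})
add = Segment(run=lambda a: [a[0] + a[1]], input_map={0: 2, 1: 1}, output_map={0: 3})
print(execute_segments([double, add], [5, 1], output_indices=[3]))  # [11]
```

Keying the state by original index is what lets a later segment read an intermediate produced by an earlier one without either segment knowing the other's local numbering.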

C++ functions:

Step 1: setupSegmentation runs the segmentation algorithm on the C++ Fusion to create the SegmentedFusion. Then, the prepareGroupOrder function orders the sub-fusions according to their dependencies. setupSegmentation returns the number of segments in the SegmentedFusion.

Step 2: buildSegment creates the C++ Fusion for a given segment id, translates it to a Python FusionDefinition, and returns a mapping from the segment fusion's state indices to the original fusion's state indices.

Step 3: finalizeSegmentation destroys the segmentation state stored in the FusionDefinition.
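The mapping returned in Step 2 can be pictured as a renumbering performed during translation: each value the segment defines receives the next segment-local index, and the inverse of that assignment is the segment-to-original map. The sketch below assumes a hypothetical op record `(name, original input indices, original output index)` and emits simplified definition lines as strings (keyword arguments such as `dims` are omitted). It is not the actual buildSegment translator.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical op record: (op name, original input indices, original output index).
Op = Tuple[str, Sequence[int], int]


def translate_segment(inputs: Sequence[int],
                      ops: Sequence[Op]) -> Tuple[List[str], Dict[int, int]]:
    local_of: Dict[int, int] = {}  # original state index -> segment-local index
    lines: List[str] = []

    def new_local(original_idx: int) -> int:
        # Local indices are handed out in definition order: 0, 1, 2, ...
        # (each original value is assumed to be defined exactly once).
        local_of[original_idx] = len(local_of)
        return local_of[original_idx]

    for original_idx in inputs:
        lines.append(f"T{new_local(original_idx)} = fd.define_tensor(...)")
    for name, args, out in ops:
        # Operands are already defined, so look up their local names first.
        arg_str = ", ".join(f"T{local_of[a]}" for a in args)
        lines.append(f"T{new_local(out)} = fd.ops.{name}({arg_str})")

    # Invert the assignment: segment-local index -> original state index.
    segment_to_original = {local: orig for orig, local in local_of.items()}
    return lines, segment_to_original


# Second segment of the example, assuming original state indices follow the
# T numbering of the original fusion (T1 = 1, T3 = 3, T4 = 4).
lines, mapping = translate_segment([1, 3], [("add", [1, 3], 4)])
print(mapping)  # {0: 1, 1: 3, 2: 4}
```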

Python functions:

  1. FusionDefinition.segment calls setupSegmentation, then buildSegment for each segment, then finalizeSegmentation.
  2. If a Python FusionDefinition has segments, FusionDefinition.execute calls _execute_segments. The original FusionDefinition acts as the argument manager, running the sub-fusions in topological order.
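A minimal sketch of the call sequence in item 1, using a hypothetical fake backend that only records the calls it receives. The snake_case method names and the try/finally cleanup are assumptions made for illustration; they are not the names or the behavior of nvFuser's Python bindings.

```python
from typing import Dict, List, Tuple

Built = Tuple[str, Dict[int, int]]  # (translated definition, segment -> original index map)


class FakeSegmentationBackend:
    """Hypothetical stand-in for the C++ side; it only records the calls it receives."""

    def __init__(self, num_segments: int) -> None:
        self.num_segments = num_segments
        self.calls: List[str] = []

    def setup_segmentation(self) -> int:
        self.calls.append("setup")
        return self.num_segments

    def build_segment(self, segment_id: int) -> Built:
        self.calls.append(f"build {segment_id}")
        # A real backend would return a translated definition and its index map.
        return f"segment_{segment_id}", {}

    def finalize_segmentation(self) -> None:
        self.calls.append("finalize")


def segment(backend: FakeSegmentationBackend) -> List[Built]:
    # Set up once; the return value is the number of segments to build.
    num_segments = backend.setup_segmentation()
    try:
        # Build segments in the order established during setup.
        return [backend.build_segment(i) for i in range(num_segments)]
    finally:
        # Assumed design choice: release segmentation state even if a build fails.
        backend.finalize_segmentation()


backend = FakeSegmentationBackend(2)
segment(backend)
print(backend.calls)  # ['setup', 'build 0', 'build 1', 'finalize']
```

Wrapping the build loop in try/finally keeps the three calls paired, so a failure while translating one segment does not leave stale segmentation state behind.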

Example:

Original Fusion: A reduction + broadcast + pointwise fusion.

from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T3 = fd.ops.broadcast(T2, is_broadcast_dim=[False, True])
    T4 = fd.ops.add(T1, T3)
    fd.add_output(T4)
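To see why the intermediate T3 has shape [-1, 1] (the shape any segment consuming it has to declare), here is a simplified shape-propagation model of the three ops in plain Python. These helpers are illustrative and are not nvFuser functions; -1 is treated as an opaque symbolic extent, and two symbolic extents are assumed to match at runtime.

```python
from typing import List, Sequence

Shape = List[int]  # -1 marks a symbolic extent, as in define_tensor(shape=[-1, -1])


def sum_shape(shape: Sequence[int], dims: Sequence[int], keepdim: bool) -> Shape:
    # Reduced axes are dropped, or kept with extent 1 when keepdim=True.
    out: Shape = []
    for axis, extent in enumerate(shape):
        if axis not in dims:
            out.append(extent)
        elif keepdim:
            out.append(1)
    return out


def broadcast_shape(shape: Sequence[int], is_broadcast_dim: Sequence[bool]) -> Shape:
    # True inserts a new size-1 axis; False consumes the next input axis in order.
    if sum(not b for b in is_broadcast_dim) != len(shape):
        raise ValueError("need one False entry per input axis")
    axes = iter(shape)
    return [1 if b else next(axes) for b in is_broadcast_dim]


def add_shape(a: Sequence[int], b: Sequence[int]) -> Shape:
    # Elementwise op: a size-1 axis broadcasts against the other operand.
    if len(a) != len(b):
        raise ValueError("ranks must match")
    return [y if x == 1 else x for x, y in zip(a, b)]


t2 = sum_shape([-1, -1], dims=[1], keepdim=False)  # [-1]
t3 = broadcast_shape(t2, [False, True])            # [-1, 1]
t4 = add_shape([-1, -1], t3)                       # [-1, -1]
```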

After Segmentation:

First Segment:

def nvfuser_fusion_id2(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.ops.sum(T0, dims=[1], keepdim=False, dtype=DataType.Float)
    T2 = fd.ops.broadcast(T1, is_broadcast_dim=[False, True])
    fd.add_output(T2)

Second Segment:

def nvfuser_fusion_id3(fd: FusionDefinition) -> None:
    T0 = fd.define_tensor(shape=[-1, -1],
                          contiguity=[True, True],
                          dtype=DataType.Float,
                          is_cpu=False)
    T1 = fd.define_tensor(shape=[-1, 1],
                          contiguity=[True, None],
                          dtype=DataType.Float,
                          is_cpu=False)
    T2 = fd.ops.add(T0, T1)
    fd.add_output(T2)
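The two segments compose back into the original computation: the first segment's output (the original T3) becomes the second segment's T1, while the original T1 becomes the second segment's T0. The plain-Python check below uses nested lists instead of GPU tensors; the function names are hypothetical and it models only the arithmetic, not nvFuser execution.

```python
from typing import List

Matrix = List[List[float]]


def original_fusion(t0: Matrix, t1: Matrix) -> Matrix:
    # T4 = T1 + broadcast(sum(T0, dims=[1]))
    return [[x + sum(r0) for x in r1] for r0, r1 in zip(t0, t1)]


def first_segment(t0: Matrix) -> Matrix:
    # Sum over dim 1, then broadcast to shape [rows, 1].
    return [[sum(row)] for row in t0]


def second_segment(t0: Matrix, t1: Matrix) -> Matrix:
    # t0 has shape [rows, cols]; t1 has shape [rows, 1] and broadcasts across columns.
    return [[x + b[0] for x in row] for row, b in zip(t0, t1)]


def run_segmented(t0: Matrix, t1: Matrix) -> Matrix:
    t3 = first_segment(t0)  # produces the original T3
    # Segment-local T0 <- original T1, segment-local T1 <- original T3.
    return second_segment(t1, t3)


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[10.0, 20.0], [30.0, 40.0]]
print(run_segmented(a, b))  # [[13.0, 23.0], [37.0, 47.0]]
```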

Changes in this PR

This PR implements the setupSegmentation function for user-scheduler segmentation. It is the first PR in a stack, followed by https://github.com/NVIDIA/Fuser/pull/3335 and https://github.com/NVIDIA/Fuser/pull/3025.

  1. Creates a SegmentationState class that contains all segmentation logic for the Python frontend.
  2. Places all segmentation logic in a separate file, csrc/python_frontend/segmentation.h.
  3. FusionDefinition holds an instance of SegmentationState and exposes its logic through a public interface, which is also added to the Python bindings.
  4. Adds test_segmentation_reduction_pointwise_epilogue to test this functionality.

rdspring1 commented 2 weeks ago

!test

rdspring1 commented 1 week ago

!test