NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Task parallelism and warp-specialization #210

Open zasdfgbnm opened 1 year ago

zasdfgbnm commented 1 year ago

Motivation

On Hopper, efficient gemm requires warp-specialization, which nvFuser does not currently support. This doc proposes extending nvFuser to support this optimization. I believe this will benefit not only matmul but also other cases such as optimal cat/stack scheduling and horizontal fusion; see the "Potential applications" section for more detail.

Design

Notation: I will mostly use the term "task parallelism" for the new feature being added to nvFuser. "Warp-specialization" is the special case of "vertical task parallelism" (described below) on a thread index.

Partition of DAG as tasks

In order to use task parallelism, we first need to partition the DAG into tasks. Tasks are non-overlapping and dense: every Val in the fusion definition except fusion inputs belongs to a task (fusion inputs are special because they are given instead of computed), and each Val belongs to exactly one task. Initially, all Vals belong to task 0. Example partition:
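The two partition rules (dense, non-overlapping, inputs excluded) and the "everything starts in task 0" default can be checked mechanically. Below is a minimal C++ sketch under simplifying assumptions: `Val` here is a toy struct identified by name (not nvFuser's `Val` class), a task is just a set of Val names, and `initialPartition` / `isValidPartition` are hypothetical helpers for illustration only.

```cpp
#include <cstddef>
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy stand-in for a fusion Val (NOT nvFuser's Val): a name plus a flag
// saying whether it is a fusion input (given, not computed).
struct Val {
  std::string name;
  bool is_fusion_input;
};

// Hypothetical representation: a task is the set of Val names it computes.
using Task = std::set<std::string>;

// Initially, every computed Val belongs to task 0.
std::vector<Task> initialPartition(const std::vector<Val>& vals) {
  Task task0;
  for (const Val& v : vals)
    if (!v.is_fusion_input) task0.insert(v.name);
  return {task0};
}

// Checks the partition rules: every non-input Val is in exactly one task
// (dense + non-overlapping), fusion inputs are in no task, and no task
// mentions a name that is not a Val of the fusion.
bool isValidPartition(const std::vector<Val>& vals,
                      const std::vector<Task>& tasks) {
  std::map<std::string, int> count;  // how many tasks contain each name
  for (const Task& task : tasks)
    for (const std::string& name : task) ++count[name];

  std::size_t known = 0;  // distinct fusion Vals that appear in some task
  for (const Val& v : vals) {
    auto it = count.find(v.name);
    int c = (it == count.end()) ? 0 : it->second;
    if (v.is_fusion_input ? c != 0 : c != 1) return false;
    if (c > 0) ++known;
  }
  return known == count.size();
}
```

Using the Vals of the example in the expression sorting section, splitting the initial task 0 into five tasks such as {tv1, tv5}, {tv2, tv6}, {tv3, tv7}, {tv4, tv8}, {tv9, tv10} keeps the partition valid, while putting tv2 into two tasks or tv0 into any task breaks it.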

[Figure: example partition of the DAG into tasks]

Grouping tasks into a hierarchical structure

Tasks are further grouped into task groups. Task groups form a hierarchical structure.

[Figure: tasks grouped into a hierarchy of task groups]

Parallelization of task groups

A task group can be parallelized, for example:

group1->parallelize(ParallelType::TIDy);
group3->parallelize(ParallelType::TIDz);

Not all task groups can be parallelized. A parallelizable group is either a "horizontal group" or a "vertical group". A "horizontal group" is a group whose members have no data dependencies on each other; for example, group 1 is a horizontal group. A "vertical group" is a group whose members are connected through data dependencies; for example, group 3 is a vertical group.

Below is an example where group 4 is neither a horizontal group nor a vertical group:

[Figure: group 4, which is neither a horizontal nor a vertical group]

However, you can make it a horizontal group by grouping group 2 and group 3 together:

[Figure: group 4 made horizontal by grouping group 2 and group 3 together]
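The horizontal/vertical distinction can be decided from the dependency edges between a group's members. Below is a minimal C++ sketch, assuming "connected" means the members form a single connected component when dependency edges are treated as undirected; `GroupKind` and `classify` are hypothetical names, and members are plain indices rather than nvFuser objects.

```cpp
#include <numeric>
#include <utility>
#include <vector>

enum class GroupKind { Horizontal, Vertical, Neither };

// Members are numbered 0..num_members-1. Each edge (p, c) means member c
// consumes a Val produced by member p.
GroupKind classify(int num_members,
                   const std::vector<std::pair<int, int>>& edges) {
  // No dependency between any two members: horizontal.
  if (edges.empty()) return GroupKind::Horizontal;

  // Union-find over members to count connected components (assumption:
  // edge direction is ignored when checking "connected").
  std::vector<int> parent(num_members);
  std::iota(parent.begin(), parent.end(), 0);
  auto find = [&parent](int x) {
    while (parent[x] != x) {
      parent[x] = parent[parent[x]];  // path halving
      x = parent[x];
    }
    return x;
  };
  int components = num_members;
  for (const auto& e : edges) {
    int a = find(e.first);
    int b = find(e.second);
    if (a != b) {
      parent[a] = b;
      --components;
    }
  }
  // All members linked together: vertical; partially linked: neither.
  return components == 1 ? GroupKind::Vertical : GroupKind::Neither;
}
```

For a hypothetical group with members A, B, C where only A feeds B, `classify(3, {{0, 1}})` gives `Neither`; merging A and B into one subgroup leaves two independent members, so `classify(2, {})` gives `Horizontal`, which is the same kind of fix as grouping group 2 and group 3 together.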

Expression sorting

Expression sorting must be aware of tasks and task groups. For the example above, the sorted expressions could be:

group 3

which zooms into

group 1
group 2

which zooms into

task 1
task 2
task 3
task 4
task 5

// the following order is also valid; expression sorting is free to choose any of the valid orders

task 3
task 4
task 1
task 2
task 5

// There are more valid orders...

// In what follows, we assume the first order is picked

which zooms into

tv1 = sin(tv0);
tv5 = cos(tv1);
tv2 = tan(tv0);
tv6 = relu(tv2);
tv3 = exp(tv0);
tv7 = sigmoid(tv3);
tv4 = log(tv0);
tv8 = tanh(tv4);
tv9 = cat([tv5, tv6, tv7, tv8]);
tv10 = neg(tv9);
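One way to get an order like this is to sort the hierarchy level by level: topologically sort the members of each group by their dependencies, then expand each member recursively, so the expressions of one task or group stay contiguous. Below is a minimal C++ sketch of that idea, assuming (as I read the example) group 1 = tasks 1-4, group 2 = task 5, and group 3 = {group 1, group 2}. `Node`, `topoOrder`, and `flatten` are hypothetical names, and expressions inside a task are assumed to be already in dependency order.

```cpp
#include <functional>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

// Hypothetical hierarchy node. A task holds its expressions; a group holds
// members plus dependency edges (producer index, consumer index) among them.
struct Node {
  std::vector<std::string> exprs;              // non-empty only for tasks
  std::vector<std::shared_ptr<Node>> members;  // non-empty only for groups
  std::vector<std::pair<int, int>> deps;       // edges between members
};

// Kahn's algorithm over one group's members. Ties go to the lowest index,
// which picks just one of the valid orders.
std::vector<int> topoOrder(const Node& g) {
  int n = static_cast<int>(g.members.size());
  std::vector<int> indeg(n, 0);
  std::vector<std::vector<int>> out(n);
  for (const auto& e : g.deps) {
    out[e.first].push_back(e.second);
    ++indeg[e.second];
  }
  std::priority_queue<int, std::vector<int>, std::greater<int>> ready;
  for (int i = 0; i < n; ++i)
    if (indeg[i] == 0) ready.push(i);
  std::vector<int> order;
  while (!ready.empty()) {
    int m = ready.top();
    ready.pop();
    order.push_back(m);
    for (int c : out[m])
      if (--indeg[c] == 0) ready.push(c);
  }
  return order;  // shorter than n if the member graph had a cycle
}

// Expands the hierarchy into a flat expression list, group by group.
void flatten(const Node& node, std::vector<std::string>& result) {
  if (node.members.empty()) {  // a task: emit its expressions as-is
    result.insert(result.end(), node.exprs.begin(), node.exprs.end());
    return;
  }
  for (int m : topoOrder(node)) flatten(*node.members[m], result);
}
```

Listing group 3's members as {group 2, group 1} with the edge group 1 -> group 2 still yields group 1's expressions first, because only the dependencies constrain the order.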

Loop nest generation

When generating the loop nest, an unparallelized group simply generates its members one after another. A parallelized group generates kir::IfThenElse nodes that dispatch between its members.

Assume tv0-tv8 have [BIDx, TIDy{size0}, TIDx] and tv9 and tv10 have [BIDx, TIDy{size0*4}, TIDx] (the cat dim is 1; in this example I assume the scheduler leaves the cat dim untouched), that group 1 and group 3 are parallelized on TIDy and TIDz as shown earlier, and that everything is inlined with inline most.

Then the generated loop nest structure will be:

FOR blockIdx.x in BIDx:
  IF threadIdx.z == 0:
    IF threadIdx.y >= 0 && threadIdx.y < size0:
      FOR threadIdx.y in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv1 = sin(tv0);
          tv5 = cos(tv1);
    ELSE IF threadIdx.y - size0 >= 0 && threadIdx.y - size0 < size0:
      FOR threadIdx.y - size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv2 = tan(tv0);
          tv6 = relu(tv2);
    ELSE IF threadIdx.y - 2*size0 >= 0 && threadIdx.y - 2*size0 < size0:
      FOR threadIdx.y - 2*size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv3 = exp(tv0);
          tv7 = sigmoid(tv3);
    ELSE IF threadIdx.y - 3*size0 >= 0 && threadIdx.y - 3*size0 < size0:
      FOR threadIdx.y - 3*size0 in TIDy{size0}:
        FOR threadIdx.x in TIDx:
          tv4 = log(tv0);
          tv8 = tanh(tv4);
  ELSE IF threadIdx.z == 1:
    FOR threadIdx.y in TIDy{4*size0}:
      FOR threadIdx.x in TIDx:
        tv9 = cat([tv5, tv6, tv7, tv8]);
        tv10 = neg(tv9);
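The IF / ELSE IF chain for group 1 is plain range arithmetic on threadIdx.y: member i owns the half-open range starting at the sum of the previous members' extents, and its loop index is threadIdx.y minus that offset. Below is a minimal host-side C++ sketch of that arithmetic, not nvFuser's lowering code; `Dispatch`, `dispatch`, and `totalExtent` are hypothetical names, and in a kernel `index` would be threadIdx.y.

```cpp
#include <vector>

// Which member of a horizontally parallelized group a thread runs, and its
// member-local index (e.g. threadIdx.y - 2*size0 in the pseudocode above).
struct Dispatch {
  int member;  // -1 if the index is past all members (the thread idles)
  int local;
};

// Member i owns [offset_i, offset_i + extents[i]), where offset_i is the sum
// of the preceding extents, mirroring the generated IF / ELSE IF chain.
Dispatch dispatch(int index, const std::vector<int>& extents) {
  int offset = 0;
  for (int i = 0; i < static_cast<int>(extents.size()); ++i) {
    if (index >= offset && index < offset + extents[i]) {
      return {i, index - offset};
    }
    offset += extents[i];
  }
  return {-1, 0};
}

// The parallel dimension must cover all members; for group 1 this is
// TIDy{4*size0}, the same extent tv9 and tv10 use.
int totalExtent(const std::vector<int>& extents) {
  int total = 0;
  for (int e : extents) total += e;
  return total;
}
```

With size0 = 8, the four members of group 1 cover threadIdx.y in [0, 32), and threadIdx.y = 9 lands in the second member (tan/relu) with local index 1.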

Synchronization

For the parallelization of horizontal task groups, synchronization must happen before and after the dispatch. Depending on the parallel type, a block sync (for thread parallel types) or a grid sync (for block parallel types) might be needed. For the parallelization of vertical task groups (a.k.a. warp specialization), the parallelization boundary (in this case tv5-tv8) must be double- or circular-buffered, and arrive-wait barriers are used for synchronization.
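The circular-buffered boundary of a vertical group follows a producer/consumer protocol with two signals per stage: "full" (data is ready) and "empty" (the stage may be overwritten). Below is a minimal CPU-side C++ sketch of that protocol, assuming two `std::thread`s stand in for the load warps and the compute warps, and a mutex/condition-variable `Semaphore` (hypothetical) stands in for the hardware arrive-wait barrier. It illustrates the synchronization pattern only, not the code nvFuser would generate.

```cpp
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical stand-in for an arrive-wait barrier: a counting semaphore.
class Semaphore {
 public:
  explicit Semaphore(int count) : count_(count) {}
  void arrive() {
    {
      std::lock_guard<std::mutex> lock(m_);
      ++count_;
    }
    cv_.notify_one();
  }
  void wait() {
    std::unique_lock<std::mutex> lock(m_);
    cv_.wait(lock, [this] { return count_ > 0; });
    --count_;
  }

 private:
  std::mutex m_;
  std::condition_variable cv_;
  int count_;
};

// Producer fills stage i % num_stages of a circular buffer; consumer reads it.
// Returns the consumer's result (the sum of all items) for checking.
long long pipeline(int num_items, int num_stages) {
  std::vector<int> buffer(num_stages);
  std::vector<std::unique_ptr<Semaphore>> empty, full;
  for (int s = 0; s < num_stages; ++s) {
    empty.push_back(std::make_unique<Semaphore>(1));  // all stages start free
    full.push_back(std::make_unique<Semaphore>(0));   // nothing loaded yet
  }
  long long sum = 0;
  std::thread producer([&] {
    for (int i = 0; i < num_items; ++i) {
      int s = i % num_stages;
      empty[s]->wait();   // wait until the consumer released this stage
      buffer[s] = i;      // "load" item i into the stage
      full[s]->arrive();  // signal that the stage holds data
    }
  });
  std::thread consumer([&] {
    for (int i = 0; i < num_items; ++i) {
      int s = i % num_stages;
      full[s]->wait();     // wait for the producer's data
      sum += buffer[s];    // "compute" on the stage
      empty[s]->arrive();  // release the stage for reuse
    }
  });
  producer.join();
  consumer.join();
  return sum;
}
```

With more than one stage, the producer can load ahead of the consumer by up to `num_stages` items, which is what lets loading and computation overlap.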

Potential applications

Efficient matmul on Hopper

See: https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp

That kernel uses warp specialization: loading is done in some warps, and mma+store in others.

Horizontal fusion

For example, suppose we have multiple separate fusions, each scheduled independently. If all of them use only BIDx (not BIDy or BIDz), we can trivially fuse them horizontally: make each fusion a task in the combined fusion and horizontally parallelize these tasks on BIDx.
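Fusing horizontally on BIDx means the combined kernel's grid is the concatenation of the members' grids, while the whole kernel shares a single blockDim. Below is a minimal C++ sketch of the combined launch configuration, assuming each member's blockDim is widened to the per-dimension maximum and that threads outside a member's own blockDim skip its work (my assumption, not stated in the proposal). `Launch`, `Combined`, and `combine` are hypothetical names.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical per-fusion launch configuration. Only grid x is tracked,
// since the fused kernels are assumed to use BIDx only.
struct Launch {
  int grid_x;
  int block_x, block_y, block_z;
};

struct Combined {
  Launch launch;
  std::vector<int> first_block;  // fusion i starts at this blockIdx.x
};

// Concatenate grids along BIDx; take the max blockDim in each dimension,
// because one kernel launch has a single blockDim.
Combined combine(const std::vector<Launch>& fusions) {
  Combined out{{0, 1, 1, 1}, {}};
  for (const Launch& f : fusions) {
    out.first_block.push_back(out.launch.grid_x);
    out.launch.grid_x += f.grid_x;
    out.launch.block_x = std::max(out.launch.block_x, f.block_x);
    out.launch.block_y = std::max(out.launch.block_y, f.block_y);
    out.launch.block_z = std::max(out.launch.block_z, f.block_z);
  }
  return out;
}
```

Fusion i then runs on blockIdx.x in [first_block[i], first_block[i] + grid_x_i), using the same kind of range dispatch as the TIDy chain in the loop nest above, just on BIDx.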

Cat/stack schedule

For cat/stack, the size of the output tensor is naturally the sum of the sizes of the inputs, so we could parallelize the computation of the inputs the same way group 1 is parallelized in the example above.
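Concretely, each input becomes one task that writes into a disjoint slice of the output, starting at the sum of the sizes of the inputs before it. Below is a minimal C++ sketch for 1-D inputs concatenated along their only dimension; `catByTasks` is a hypothetical name, and the loop over inputs stands in for tasks that would run concurrently on disjoint TIDy ranges on the GPU.

```cpp
#include <cstddef>
#include <vector>

// Task j copies inputs[j] into out[offset_j, offset_j + size_j), where
// offset_j is the sum of the sizes of inputs 0..j-1. The slices are
// disjoint, so the tasks need no synchronization among themselves.
std::vector<float> catByTasks(const std::vector<std::vector<float>>& inputs) {
  std::size_t total = 0;
  for (const auto& in : inputs) total += in.size();
  std::vector<float> out(total);
  std::size_t offset = 0;
  for (const auto& in : inputs) {  // one iteration == one task
    for (std::size_t k = 0; k < in.size(); ++k) out[offset + k] = in[k];
    offset += in.size();
  }
  return out;
}
```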

zasdfgbnm commented 1 year ago

call for review: @naoyam @csarofeen @mmigdal-nv @drzejan2

naoyam commented 1 year ago

The multi-device runtime by @samnordmann seems quite relevant. He's extending the overall design. Maybe we should also consider finer-grained task parallelism like warp specialization.

drzejan2 commented 1 year ago

In the case of matmul with TensorCores, the horizontal group of tasks will effectively mean partitioning the input tensors into sub-regions, and within each such task there will be a vertical group with two tasks:

  • handling data loading (TMA)
  • processing data with TensorCore?

Or is this too low level, and for now by task we mean a higher-level operation, not low-level memory management?

samnordmann commented 1 year ago

> The multi-device runtime by @samnordmann seems quite relevant. He's extending the overall design. Maybe we should also consider finer-grained task parallelism like warp specialization.

I would be happy to discuss it. I'll post some design docs here when they are ready.

zasdfgbnm commented 1 year ago

> In the case of matmul with TensorCores, the horizontal group of tasks will effectively mean partitioning the input tensors into sub-regions, and within each such task there will be a vertical group with two tasks:
>
>   • handling data loading (TMA)
>   • processing data with TensorCore?
>
> Or is this too low level, and for now by task we mean a higher-level operation, not low-level memory management?

You are right: the goal is to parallelize TMA loads and tensor core computation.