NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Improve performance on RoPE (and code around it). #1597

Open · wujingyue opened this issue 9 months ago

wujingyue commented 9 months ago

I'm separating this from #1502. While we can get rid of cat in some cases, improving nvFuser's codegen for slice and cat will still benefit the RoPE module and the QKV split around it.

https://github.com/NVIDIA/Fuser/blob/bug1597/qkv_split_rope.py has the fusion definition for the forward pass.

Below is a whiteboard illustration for convenience:

[whiteboard illustration of the QKV split and RoPE computation, with a legend]
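
For readers who don't want to open the gist, here is a minimal PyTorch sketch of the lit-gpt-style RoPE pattern under discussion (illustrative only; names and slicing follow the standalone script later in this thread). The slices and the two cats are the resize ops nvFuser currently struggles with:

import torch

def rope_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, n_elem: int) -> torch.Tensor:
    x_rope = x[..., :n_elem]                            # slice: the part RoPE rotates
    x1 = x_rope[..., : n_elem // 2]                     # slice: first half
    x2 = x_rope[..., n_elem // 2 :]                     # slice: second half
    rotated = torch.cat((-x2, x1), dim=-1)              # cat #1: rotate the halves
    roped = x_rope * cos + rotated * sin
    return torch.cat((roped, x[..., n_elem:]), dim=-1)  # cat #2: reattach the untouched tail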

wujingyue commented 9 months ago
$ NVFUSER_DUMP=segmented_fusion pytest qkv_split_rope.py -s > stdout 2>&1 

$ grep 'g{(' stdout 
g{(no_op)
g{(pointwise)
g{(no_op)
g{(no_op)
g{(pointwise)
g{(pointwise)
g{(pointwise)
g{(pointwise)

Currently, the fusion is segmented into eight kernels. The alias support made three of them no-ops, but the remaining five contain non-meta ops like cast, neg, mul, and add. I think we'll have to fix our segmentation around slicing.

wujingyue commented 9 months ago

nvFuser requires the inputs of resizing ops like slice and cat to be fusion inputs: https://github.com/NVIDIA/Fuser/blob/cb4bd59e1a75e1c6bcaf19ee18a6eb53bf0d8de3/csrc/scheduler/registry_utils.cpp#L193. That explains the many segments. (However, I'd expect 4, not 5.)

In another thread, @naoyam explained why the inputs of resizing ops have to be fusion inputs at this moment; I've copied his explanation below:

By default, we only allow those ops when their inputs are fusion inputs. The reason is that in our current parallelization scheme, the input to a slice op and its output, when parallelized, may require communication between threads, which can be threads of different thread blocks, which in turn requires grid synchronization. For example:

// Schematic example (not exact nvFuser API): copy a 128-element tensor,
// then slice off the first and last elements, leaving 126.
TensorView* t0 = makeConcreteTensor({128});
auto t1 = set(t0);
auto t2 = slice(t1, ...);  // drop the boundary elements of t1

// Tile the slice output by 32 and propagate the transform to its producers.
t2->split(0, 32);
propagateTransformationFrom(t2);

// Bind the tiles to blocks/threads and propagate the parallelization.
t2->axis(0)->parallelize(ParallelType::BIDx);
t2->axis(1)->parallelize(ParallelType::TIDx);
propagateParallelizationFrom(t2);

After this the t1 and t2 tensors would look like:

t1: [BIDx(128/32), TIDx(32)]
t2: [BIDx(126/32), TIDx(32)]

This would mean that, for example, in the second thread block, the first two threads would need to read data held by the first thread block. That is possible but requires a grid synchronization, which means the number of thread blocks can't exceed the number of SMs; so, unless the fusion launches a small number of thread blocks, it would fail at launch time. Thus, by default, we always segment a fusion at those ops so that the op inputs are always fusion inputs. You could also try NVFUSER_ENABLE=memory_promotion. It would not segment those ops but may result in grid synchronizations.
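
To make the cross-block dependency concrete, here is a quick back-of-the-envelope check (plain Python over indices, not nvFuser API; the exact slice offsets are illustrative): map each element of t1 and of t2 = slice(t1) to the (block, thread) that owns it under the split-by-32 parallelization above, and list the t2 elements whose t1 input is produced by a different thread block.

TILE = 32

def owner(i):
    # (block, thread) that computes element i under split(0, 32) + BIDx/TIDx
    return i // TILE, i % TILE

# Suppose t2[i] reads t1[i + 1]; find elements whose producer lives in another block.
cross_block = [
    (i, owner(i), owner(i + 1))
    for i in range(126)
    if owner(i)[0] != owner(i + 1)[0]
]
print(cross_block)  # e.g. t2[31] is computed by block 0 but needs t1[32], produced by block 1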

wujingyue commented 9 months ago

NVFUSER_ENABLE=memory_promotion doesn't work out of the box.

test_rope failed with

FAILED qkv_split_rope.py::test_rope - RuntimeError: !fallback_mode_enabled_ INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/pass/expr_sort.cpp":1096, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Couldn't succcessfully sort out the fusion expressions. There are remaining connections of the heirarchical segmentation which should have been flattened to a single ordered group, or disjoint ordered groups.

test_qkv_split_rope failed with

FAILED qkv_split_rope.py::test_qkv_split_rope - RuntimeError: producer->getMemoryType() == MemoryType::Global INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":763, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV62 (T62_l[ iblockIdx.y1125{( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T0) )).logical_size ))[0] ) * ( (( (( getMetaData(T0) )).logical_si...

(The reproducer can be found in the development branch attached to this issue.)

cc @naoyam who might know where to start debugging...

naoyam commented 9 months ago

Hmm, the expression sorting is scarier to debug..., so first, about the second one: that seems to indicate we failed to annotate a tensor with the correct memory type. In this case, the producer and consumer tensors have different BID parallelizations, so we need the producer to be placed in global memory and a grid sync inserted. In the schedulers we do this memory placement automatically for certain operations, but it seems that's not working as intended for this case. See promoteMemoryTypes.

As for the expression sorting, I'm sorry I don't have any immediate hint. Typically we would need to dive into the sorting process, which is quite complicated (though it is based on the same baseline algorithm as segmentation).

wujingyue commented 9 months ago

The following broadcast triggered the producer->getMemoryType() == MemoryType::Global error.

T20_l[ iblockIdx.y1086{( ceilDiv(( ceilDiv(( 1 * ( i0 * i1 ) ), 8) ), blockDim.x) )}, bblockIdx.x1088{( ceilDiv(1, 1) )}, bUS1089{1}, iS1085{8}, ithreadIdx.x1087{blockDim.x} ] ca_pos( 5 )
   = broadcast( T62_l[ iblockIdx.y1125{( ceilDiv(( ceilDiv(( i0 * i1 ), 8) ), blockDim.x) )}, iS1124{8}, ithreadIdx.x1126{blockDim.x} ] )

The first dimension of T20_l and the first dimension of T62_l didn't pass the useSameIndex check at https://github.com/NVIDIA/Fuser/blob/7d576b49edf9e87b7fd355de23e528f2af51bd78/csrc/device_lower/analysis/sync_information.cpp#L755

iblockIdx.y1086{( ceilDiv(( ceilDiv(( 1 * ( i0 * i1 ) ), 8) ), blockDim.x) )} is mathematically the same as iblockIdx.y1125{( ceilDiv(( ceilDiv(( i0 * i1 ), 8) ), blockDim.x) )}. However, the former contains a 1 *, failing useSameIndex.

I suspect the scheduler did something uncommon for this example; I don't usually see this 1 * in pointwise-scheduled fusions. I'll keep debugging next week...
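
For what it's worth, the two extents differ only by that 1 * factor, and a generic symbolic simplifier sees them as equal immediately. A toy illustration outside nvFuser (using sympy, purely to show the kind of trivial equivalence the useSameIndex/ComputeAtMap path would need to recognize):

import sympy as sp

i0, i1, bdimx = sp.symbols("i0 i1 blockDim_x", positive=True, integer=True)
ceil_div = lambda a, b: sp.ceiling(a / b)

ext_T20 = ceil_div(ceil_div(1 * (i0 * i1), 8), bdimx)  # iblockIdx.y1086 extent
ext_T62 = ceil_div(ceil_div(i0 * i1, 8), bdimx)        # iblockIdx.y1125 extent

print(ext_T20 == ext_T62)  # True: the 1 * factor simplifies away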

PS: you can dump the full log by running:

$ NVFUSER_ENABLE=memory_promotion NVFUSER_DUMP=fusion_ir_math python -m pytest qkv_split_rope.py -k test_qkv_split_rope -s
naoyam commented 9 months ago

> The following broadcast triggered the producer->getMemoryType() == MemoryType::Global error. [...] The first dimension of T20_l and the first dimension of T62_l didn't pass the useSameIndex check. [...] However, the former contains a 1 *, failing useSameIndex.

Hmm, this sort of case is exactly what the analysis is supposed to be able to reason about, but :shrug:

IdModel is meant to replace this analysis, but it isn't ready yet.

wujingyue commented 9 months ago

How about revisiting this after IdModel kicks in, @naoyam? Without that, I think the closest I can do is to teach ComputeAtMap to handle trivial equivalences like 1*x=x. I could also investigate where this 1* comes from in the first place. I don't yet find myself efficient at debugging these...

naoyam commented 9 months ago

> How about revisiting this after IdModel kicks in, @naoyam?

Yes, unless it's urgent, it would make more sense to use IdModel.

kevinstephano commented 5 months ago

I want to revive this issue and investigate what we think an optimal 1-kernel solution would be, so we understand what an IdModel-based solution needs to achieve. Therefore, I asked Jie to investigate what a 1-kernel CUDA solution might look like outside of nvFuser. The other question I was curious about is: what else do we need to worry about besides IdModel usage for slice and cat? Do we need to fix vectorization as well, for instance?

Christian gave his illustration of the operation here: [whiteboard image]

He also gave a simple script from Ivan that has a basic lit-gpt example with RoPE and the following configuration.

This is the configuration that is appropriate:

class Config:
    name = "Mistral-7B-v0.1"
    n_layer = 1  # NOTE: Use just one transformer block! Should be 32 for real 7B model
    ### n_embd, intermediate_size, n_query_groups, n_head, head_size have impact on the matmul sizes
    n_embd = 512
    intermediate_size = 3*1024
    n_query_groups = 4
    n_head = 16
    head_size = 32
    ###
    norm_eps = 1e-05
    bias = False
    lm_head_bias = False
    block_size = 1024
    padded_vocab_size = 4*1024
    rope_n_elem = 8
    bsz = 2
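
For orientation, the tensor shapes this config implies (derived here by hand; they match the standalone script later in this thread):

# Shapes implied by the config above; purely illustrative arithmetic.
bsz, block_size = 2, 1024
n_head, n_query_groups, head_size, rope_n_elem = 16, 4, 32, 8

qkv_dim = (n_head + 2 * n_query_groups) * head_size  # 768: fused q/k/v projection width
x_shape = (bsz, block_size, qkv_dim)                 # (2, 1024, 768): input to the qkv split
q_shape = (bsz, n_head, block_size, head_size)       # (2, 16, 1024, 32): q after split_qkv
cos_shape = (block_size, rope_n_elem)                # (1024, 8): RoPE touches 8 of the 32 head dims
print(x_shape, q_shape, cos_shape)
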
jjsjann123 commented 5 months ago

quick note here for myself.

  1. The initial qkv split_reshape resulted in a memcpy. We could postpone it to save a copy, but that's going to require some analysis to figure out that it's safe to move it past the concat.
  2. Alias analysis isn't aggressively setting the segmenter hint. In this example, the slice and multiply ended up being grouped by the segmenter, giving us an extra kernel (because of the data dependency, it cannot be merged with the final add before the concat).
  3. The requirement that concat's inputs must be kernel inputs means that, without moving concat out, our best effort is 3 kernels for computing q from qkv.
jjsjann123 commented 4 months ago

Following the concept of 1 kernel for rope:

nvFuser handles cat as pad + add, where pad introduces a resize, which in turn requires kernel segmentation. However, if we break cat into pad & add, we can theoretically propagate the pad operations back along some pointwise operations to the beginning of the fusion. Pushing it further, we should be able to even merge pad operations together.
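
For concreteness, a PyTorch analogue of that decomposition (this is just the arithmetic identity, not nvFuser's internal representation): a cat along a dimension equals zero-padding each input to the full extent and adding.

import torch
import torch.nn.functional as F

a = torch.randn(2, 16, 1024, 4)
b = torch.randn(2, 16, 1024, 4)

ref = torch.cat((a, b), dim=-1)

# F.pad takes (left, right) widths for the last dimension; zero-fill by default.
via_pad_add = F.pad(a, (0, b.size(-1))) + F.pad(b, (a.size(-1), 0))

assert torch.equal(ref, via_pad_add)

Written this way, each pad is just a predicated read of its input, which is why the pads can in principle be pushed toward the fusion inputs and merged.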

With a prototype nvFuser fusion definition using that optimization, we are generating a single kernel from this example: https://gist.github.com/jjsjann123/2c4db9f6659cfe2cc8aa9503cb8a806c

Performance isn't ideal; our achieved bandwidth is low (~300 GB/s, and that's double-counting some input buffers, see note 2 below).

  1. Looking at the generated kernel, we are getting vectorized writes to gmem, but not on loads (padded tensors are accessed via a ternary ?: op and are not vectorized). cc'ing @naoyam
  2. There are likely optimization opportunities here: q_left / q_left_cos are the same buffer, but they would not present as such after segmentation. I'm hoping this would just hit L2 in the kernel, but that remains to be verified. Note that we can certainly add more analysis in graph mutation to recognize the aliased buffer here as well.
csarofeen commented 4 months ago

@jjsjann123 could you please turn this into a minimal example and show the full nvFuser IR before sending it into scheduling and the resulting CUDA code after scheduling it?

Would also be good to have real timers and a validation function.

I forget how pad, resize, and concat translate into CUDA. A simple minimal example would be helpful to think through some basic details of slice/concat.

jjsjann123 commented 4 months ago

> @jjsjann123 could you please turn this into a minimal example and show the full nvFuser IR before sending it into scheduling and the resulting CUDA code after scheduling it?

Per @csarofeen's request: https://gist.github.com/jjsjann123/0e0d76b162628b34495f0f5bd0422e17

Note that I didn't put in timers, since this is a tiny kernel and measurement on the Python side isn't accurate for kernel performance.
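
If rough wall-clock numbers do become useful later, one option is CUDA-event timing around the jitted call (a sketch; for a kernel this small, nsys profiles like the ones below remain the ground truth):

import torch

def time_gpu(fn, *args, warmup=10, iters=100):
    # Time fn(*args) with CUDA events; the result still includes any host-side
    # launch overhead inside fn, so treat it as an upper bound on kernel time.
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

# e.g. time_gpu(thunder_rope_one, q, cos, sin, rope_n_elem)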

jjsjann123 commented 3 months ago

Linking PRs: #2373, #2490

With the preseg optimization, I'm seeing a single kernel from the RoPE function below. Baby steps~~

import torch
import thunder

# operations to prepare q, k before sending it into rope
def split_qkv(x, n_head, n_query_groups, head_size):
    (
        B,
        T,
        C,
    ) = x.size()
    q_per_kv = n_head // n_query_groups
    total_qkv = q_per_kv + 2
    qkv = x.view(
        B, T, n_query_groups, total_qkv, head_size)
    qkv = qkv.permute(0, 2, 3, 1, 4)
    q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
    k = k.expand(B, n_query_groups, q_per_kv, T, head_size)
    v = v.expand(B, n_query_groups, q_per_kv, T, head_size)
    q = q.reshape(B, -1, T, head_size)
    k = k.reshape(B, -1, T, head_size)
    v = v.reshape(B, -1, T, head_size)
    return q, k, v

def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    x_rope = x[..., : rope_n_elem]
    x1 = x_rope[..., : rope_n_elem // 2]  # (B, nh, T, hs/2)
    x2 = x_rope[..., rope_n_elem // 2 :]  # (B, nh, T, hs/2)
    rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
    roped = (x_rope * cos) + (rotated * sin)
    roped = roped.to(dtype=x.dtype)  # cast back to the input dtype
    return torch.cat((roped, x[..., rope_n_elem :]), dim=-1)

dtype = torch.bfloat16
device = "cuda"
bsz = 2
block_size = 1024
n_head = 16
head_size = 32
n_query_groups = 4
rope_n_elem = 8

x = torch.randn([bsz, block_size, (n_head + 2 * n_query_groups) * head_size], device=device, dtype=dtype)
# note cos/sin could have a non-contiguous inner dimension.
cos = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)
sin = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)

# we only need to look at a single entry for fusion
# and we are not looking at operations prior to apply_rope
q, _, _ = split_qkv(x, n_head, n_query_groups, head_size)
thunder_rope_one = thunder.jit(rope_one_entry, nv_enable_bookend=False)
o = thunder_rope_one(q, cos, sin, rope_n_elem)
print(thunder.last_traces(thunder_rope_one)[-1])

o_ref = rope_one_entry(q.float(), cos.float(), sin.float(), rope_n_elem).to(dtype=dtype)
assert o.allclose(o_ref)
jjsjann123 commented 3 months ago

The example script above has regressed in performance compared to the earlier manual implementation.

As of commit b2fca304564ac4e8740d1d903ffc2cc8ea9de65f, the measured numbers are:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     51.1           181344          5   36268.8   35392.0     35296     39872       2014.8  <unnamed>::nvfuser_pointwise_f0_c1_r0_g4(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…

Compared to the current main branch, which gives 4 nvFuser kernels:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     41.0           174754          5   34950.8   34912.0     34721     35392        259.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g4(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…
      6.7            28703          5    5740.6    5728.0      5727      5760         17.7  <unnamed>::nvfuser_pointwise_f0_c1_r0_g5(<unnamed>::Tensor<<unnamed>::__bfloat, (int)2, (int)2>, <u…
      6.7            28608          5    5721.6    5664.0      5632      6016        165.4  <unnamed>::nvfuser_pointwise_f0_c1_r0_g2(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…
      4.8            20352          5    4070.4    4032.0      4032      4192         69.4  <unnamed>::nvfuser_pointwise_f0_c1_r0_g3(<unnamed>::Tensor<<unnamed>::__bfloat, (int)2, (int)2>, <u…


While previously, with the manual implementation, we had:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     14.8            18144          1   18144.0   18144.0     18144     18144          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g6(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…

That's a 2x perf regression, which is a bit surprising. Then I realized I had accidentally used static shapes in the reference implementation; if I switch to dynamic shapes, the perf is much worse than in the PR:

 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     42.3            56640          1   56640.0   56640.0     56640     56640          0.0  <unnamed>::nvfuser_pointwise_f0_c1_r0_g6(<unnamed>::Tensor<<unnamed>::__bfloat, (int)4, (int)4>, <u…