NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

group norm segmented into pointwise + persistent + pointwise #2375

Open liqiangxl opened 3 months ago

liqiangxl commented 3 months ago

Group norm is calculated as:

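x0 = [N, C, H, W]
x1 = x0.cast(fp32).reshape((N, C, H, W) --> (N, G, C/G, H, W))
mean = x1.sum(C/G, H, W) / (C/G * H * W)
var = ((x1 - mean)^2).sum(C/G, H, W) / (C/G * H * W)
x2 = (x1 - mean) * rsqrt(var + eps)
x3 = x2.reshape((N, G, C/G, H, W) --> (N, C, H, W))
x4 = w*x3 + b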
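To make the two reshapes concrete, here is a minimal pure-Python sketch (not nvFuser or PyTorch code; `group_norm_flat` is a hypothetical helper, and the fp32 cast is not modeled) that runs group norm on a flat, contiguous NCHW buffer. It assumes C is divisible by G and uses the biased variance (correction=0), as var_mean does in the Thunder trace later in this thread. It illustrates that both reshapes only change how the same flat buffer is indexed:

```python
import math


def group_norm_flat(x, N, C, H, W, G, w, b, eps=1e-5):
    """Group norm over a flat, contiguous NCHW buffer `x` (a list of floats).

    Viewing x as [N, G, C/G, H, W] moves no data: group g of sample n is the
    contiguous chunk of (C/G)*H*W elements starting at (n*G + g)*(C/G)*H*W.
    """
    assert C % G == 0, "C must be divisible by G"
    group_size = (C // G) * H * W
    y = [0.0] * len(x)
    # x2 = (x1 - mean) * rsqrt(var + eps), one chunk per (n, g) pair.
    for ng in range(N * G):
        start = ng * group_size
        chunk = x[start:start + group_size]
        mean = sum(chunk) / group_size
        var = sum((v - mean) ** 2 for v in chunk) / group_size  # correction=0
        rstd = 1.0 / math.sqrt(var + eps)
        for i, v in enumerate(chunk):
            y[start + i] = (v - mean) * rstd
    # Reshaping back to [N, C, H, W] is again a no-op on the flat buffer;
    # x4 = w*x3 + b only needs the channel index of each element.
    for idx in range(len(y)):
        c = (idx // (H * W)) % C
        y[idx] = w[c] * y[idx] + b[c]
    return y


# One sample, two channels in a single group: 1.0 and 3.0 normalize to roughly -1 and 1.
print(group_norm_flat([1.0, 3.0], N=1, C=2, H=1, W=1, G=1, w=[1.0, 1.0], b=[0.0, 0.0]))
```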

Because of the two reshapes, the normalization scheduler rejects the unsegmented fusion, so it is segmented into three sub-fusions: (1) pointwise doing cast + reshape, (2) normalization, and (3) pointwise doing reshape + scale & bias.

Reproducer (modified from the apex group norm implementation and the apex group norm test):

import torch
import thunder

def torch_group_norm(x, g, w, b, eps, act=""):
    xdtype, wdtype = x.dtype, w.dtype
    if xdtype != wdtype:
        x = x.to(dtype=wdtype)
    y = torch.nn.functional.group_norm(x, g, w, b, eps)
    if act in ["silu", "swish"]:
        y = torch.nn.functional.silu(y)
    if xdtype != wdtype and y.dtype != xdtype:
        y = y.to(dtype=xdtype)
    return y

def verify_group_norm(N=32,
                      C=128,
                      H=256,
                      W=256,
                      G=32,
                      xdtype=torch.float16,
                      wdtype=torch.float32,
                      eps=1e-5,
                      memory_format=torch.channels_last,
                      device='cuda',
                      act=""):
    # create data
    x_shape = (N, C, H, W)
    w_shape = (C,)
    weight = torch.rand(w_shape,
                        dtype=wdtype,
                        device=device,
                        requires_grad=True)
    bias = torch.rand(w_shape,
                      dtype=wdtype,
                      device=device,
                      requires_grad=True)
    x = torch.randn(x_shape, dtype=xdtype, device=device)
    x = x.to(memory_format=memory_format)
    x.requires_grad_(True)
    thunder_group_norm = thunder.jit(torch_group_norm)
    y_torch = torch_group_norm(x, G, weight, bias, eps, act)
    y_thunder = thunder_group_norm(x, G, weight, bias, eps, act)
    # compare
    torch.testing.assert_close(y_thunder, y_torch, atol=4e-2, rtol=0)

# Run with: NVFUSER_DUMP=scheduler_params,cuda_to_file,fusion_ir_preseg,python_definition python group_norm.py 2>&1 | tee 1.log
if __name__ == "__main__":
    verify_group_norm(N=2, C=128, H=16, W=16)

naoyam commented 3 months ago

I wonder why (1) and (2) are not fused. I think an easy solution for the second reshape would be to move it to the end of the fusion and turn it into a metadata operation.
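The suggestion relies on a reshape of a contiguous tensor being pure metadata. As a minimal sketch in plain Python (the helper names are hypothetical, not nvFuser or PyTorch APIs), assuming a row-major contiguous NCHW layout (the reproducer uses channels_last, which this sketch does not model), the check below confirms that every element keeps the same memory offset when [N, C, H, W] is viewed as [N, G, C/G, H, W]; only the shape and strides change:

```python
from itertools import product


def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, counted in elements."""
    strides = []
    acc = 1
    for size in reversed(shape):
        strides.append(acc)
        acc *= size
    return tuple(reversed(strides))


def offset(index, strides):
    """Memory offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


def group_reshape_is_metadata_only(N, C, H, W, G):
    """Check that splitting C into (G, C/G) never changes an element's offset."""
    cg = C // G
    src = contiguous_strides((N, C, H, W))
    dst = contiguous_strides((N, G, cg, H, W))
    for n, c, h, w in product(range(N), range(C), range(H), range(W)):
        if offset((n, c, h, w), src) != offset((n, c // cg, c % cg, h, w), dst):
            return False
    return True


print(contiguous_strides((2, 128, 16, 16)))    # (32768, 256, 16, 1)
print(contiguous_strides((2, 32, 4, 16, 16)))  # (32768, 1024, 256, 16, 1)
print(group_reshape_is_metadata_only(2, 128, 4, 4, 32))  # True
```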

liqiangxl commented 3 months ago

It's due to the reductionInterferingView check. If I disable this check and remove the last reshape in group norm (no weight and bias), nvFuser uses the inner persistent scheduler but fails with the error message "Merging IterDomains requires that their iteration types match. Outer: iS133{32}, Inner: rS17{i2}". We could also move the first reshape to the beginning of the fusion so that it is just a no-op. What do you think?

naoyam commented 3 months ago

There's only one normalization in (1) and (2), so I assume x1.sum is used as the reference for the segment. Which tensor does this conflict come from?

liqiangxl commented 3 months ago

It looks like Thunder applied some optimizations to the captured graph. Specifically, it reshapes the input before the cast to fp32, and after the computation it adds another reshape before the cast back to fp16. The second reshape caused the error.

def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  t0 = torch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t0 = ltorch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
      # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
  [t16, t5, t9] = nvFusion0(t0)
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t0, t5, t9), ())
liqiangxl commented 3 months ago

Thanks for the suggestion, @jjsjann123 and @kevinstephano. With thunder.jit(torch_group_norm, nv_enable_bookend=False), nvFuser gets both reshapes:

def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  [t16, t5, t9] = nvFusion0(x)
    # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t5, t9, x), (0,))

If we turn these two reshapes into a reshape of the input and a reshape right before the output, and revise MarkAliasesPreparePass, nvFuser goes from 3 kernels to two no-ops and one normalization kernel. I added a draft PR (#2405) with three test cases:

  1. GroupNormOriginal: segmented into 3 kernels
  2. GroupNormReshapeMovedToInputOutputNoWeightBias: one kernel
  3. GroupNormReshapeMovedToInputOutput: one kernel

So the proposed plan is a two-step approach:

liqiangxl commented 3 months ago

Adding the original related issue: https://github.com/Lightning-AI/lightning-thunder/issues/468

wujingyue commented 3 months ago

Unfortunately, the current alias analysis wasn't built for tracking aliases involving intermediates. When I wrote it, I didn't see a strong case for dealing with intermediates and expected them to be fused into a kernel and become cheap index calculations. However, in this case, it appears that the two reshapes have caused the normalization scheduler to bail out...

#2405 sounds like a reasonable extension: add segment_set after input reshapes and before output reshapes. One caveat is that there are some tricky patterns we'll need to consider; e.g., if the original input is used by an operation other than the reshape, we probably shouldn't segment out the reshape. But we can leave these important implementation details to the PR discussion.
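The extension described here can be pictured as a small rewrite over the fusion's op list. Below is a toy sketch on a made-up IR (`Op` and `insert_segment_sets` are hypothetical, not nvFuser classes or the code in #2405). It assumes the output reshape has already been moved to the end of the fusion, and it honors the caveat by leaving an input reshape in place when the fusion input has other uses:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    out: str    # name of the tensor this op produces
    kind: str   # e.g. "reshape", "cast", "norm", "segment_set"
    ins: tuple  # names of the tensors this op consumes


def insert_segment_sets(ops, fusion_inputs, fusion_outputs):
    """Toy pass: cut after input reshapes and before output reshapes.

    A reshape of a fusion input is only cut off when that input has no
    other use (the caveat from the discussion).
    """
    num_uses = {}
    for op in ops:
        for name in op.ins:
            num_uses[name] = num_uses.get(name, 0) + 1

    new_ops = []
    renamed = {}  # tensor name -> name of its segment_set copy
    for op in ops:
        ins = tuple(renamed.get(name, name) for name in op.ins)
        if op.kind == "reshape" and op.out in fusion_outputs:
            # Cut right before a reshape that produces a fusion output.
            cut = ins[0] + "_seg"
            new_ops.append(Op(cut, "segment_set", (ins[0],)))
            new_ops.append(Op(op.out, "reshape", (cut,)))
            continue
        new_ops.append(Op(op.out, op.kind, ins))
        if (op.kind == "reshape" and op.ins[0] in fusion_inputs
                and num_uses[op.ins[0]] == 1):
            # Cut right after a reshape of a single-use fusion input.
            cut = op.out + "_seg"
            new_ops.append(Op(cut, "segment_set", (op.out,)))
            renamed[op.out] = cut
    return new_ops


group_norm = [
    Op("t0", "reshape", ("x",)),  # [N, C, H, W] -> [N, G, C/G, H, W]
    Op("t1", "cast", ("t0",)),
    Op("t2", "norm", ("t1",)),    # placeholder for the normalization ops
    Op("t3", "cast", ("t2",)),
    Op("y", "reshape", ("t3",)),  # back to [N, C, H, W]
]
for op in insert_segment_sets(group_norm, {"x"}, {"y"}):
    print(op.out, "=", op.kind, op.ins)
```

With both cuts in place, the two reshapes end up in their own segments, and only the cast/normalization/cast chain is left for the normalization scheduler.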

That being said, is the normalization scheduler supposed to handle the two reshapes? That feels like the right solution to me. However, if it would take a long time, I would totally understand the need for a quicker workaround.

wujingyue commented 3 months ago

Thanks for the write-up, @liqiangxl! The problem looks quite clear to me even though I was unsure about the solution.

liqiangxl commented 3 months ago

The normalization scheduler should be able to handle the first reshape (which only splits an ID), but the second reshape (which merges an iteration ID with a reduction ID) needs a lot of work. So another option is to extend the current reduction/normalization schedulers to handle some kinds of reshapes, so that the pre-segmentation optimization pass only needs to process specific types of reshapes, e.g., a reshape that merges an iteration ID with a reduction ID.
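The two reshapes can be told apart from their shapes alone: the first only splits C into (G, C/G), while the second merges them back, which is what combines an iteration ID with a reduction ID. A minimal sketch (the helper name is hypothetical, it is not how nvFuser analyzes reshapes, and it assumes no size-0 or size-1 dimensions) that classifies a reshape as split-only:

```python
def reshape_only_splits(in_shape, out_shape):
    """Return True if reshaping in_shape -> out_shape only splits dimensions.

    Each input dim must be covered by a run of consecutive output dims whose
    sizes multiply to exactly that input dim. If a run would have to extend
    past an input dim, the reshape merges dims instead.
    Assumes all sizes are >= 2 (no size-0/size-1 dims) for simplicity.
    """
    j = 0
    for size in in_shape:
        prod = 1
        while prod < size and j < len(out_shape):
            prod *= out_shape[j]
            j += 1
        if prod != size:
            return False
    return j == len(out_shape)


# First reshape in the group norm trace: splits C=128 into (G=32, C/G=4).
print(reshape_only_splits((2, 128, 16, 16), (2, 32, 4, 16, 16)))  # True
# Second reshape: merges (32, 4) back into 128.
print(reshape_only_splits((2, 32, 4, 16, 16), (2, 128, 16, 16)))  # False
```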

liqiangxl commented 3 months ago

@jjsjann123, @wujingyue, and @naoyam, thanks for the helpful discussions. Here is a summary of the approach:

naoyam commented 3 months ago

Trying to catch up on what's going on here, but I'm still confused about why these two are not fused:

> (1) pointwise doing cast + reshape (2) normalization

Doesn't the second reshape belong to the third segment?

> (3) pointwise doing reshape + scale & bias

wujingyue commented 3 months ago

> Here is a summary of the approach

LGTM. I'll defer to others on whether we should fix the scheduler(s) to accept both types of reshapes. That sounds like the right fix to me, even though it may take longer.

liqiangxl commented 3 months ago

> Trying to catch up on what's going on here, but I'm still confused about why these two are not fused:
>
> (1) pointwise doing cast + reshape (2) normalization

It is rejected by reductionInterferingView(). This check divides the IDs of the reduction tv into groups according to whether they are iteration or reduction IDs. It then builds disjoint sets over all the IDs and checks whether any entry of the disjoint sets contains IDs that belong to two different groups.

For example, consider a tv with [{i0}, {32}, **{i2/32}**, **{i3}**, **{i4}**], where bold marks the reduction dims.

I think we can safely skip these checks if the reshape only involves splits, since a split won't merge an iteration ID with a reduction ID, which is the root limitation of the current reduction scheduler. The fix is simple and avoids the complex checks. Another approach is to revise the logic of the disjoint-set-based checks, but that seems more complex than the current approach in #2437, so this option is not explored.
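The disjoint-set check described above can be approximated in a few lines. This is a toy model, not nvFuser's reductionInterferingView(): IterDomains are plain strings, every split/merge unions all of its inputs and outputs into one set, and `DisjointSets`/`view_interferes` are hypothetical names. With the group norm example, the split-only first reshape already puts {32} (iteration) and {i2/32} (reduction) into the same set, which reproduces the rejection:

```python
class DisjointSets:
    """Minimal union-find over hashable IDs."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]  # path halving
            x = self.parent[x]
        return x

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)


def view_interferes(reference_ids, view_exprs):
    """reference_ids: {id_name: "iter" | "reduction"} for the reduction tv.
    view_exprs: list of (inputs, outputs) ID-name tuples, one per split/merge.
    Returns True if some disjoint set holds both iteration and reduction IDs.
    """
    sets = DisjointSets()
    for inputs, outputs in view_exprs:
        ids = list(inputs) + list(outputs)
        for other in ids[1:]:
            sets.union(ids[0], other)
    kinds_per_set = {}
    for id_name, kind in reference_ids.items():
        kinds_per_set.setdefault(sets.find(id_name), set()).add(kind)
    return any(len(kinds) > 1 for kinds in kinds_per_set.values())


reference = {"i0": "iter", "32": "iter", "i2/32": "reduction",
             "i3": "reduction", "i4": "reduction"}
first_reshape = [(("i2",), ("32", "i2/32"))]  # split i2 -> (32, i2/32)
print(view_interferes(reference, first_reshape))  # True: rejected
```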

> Doesn't the second reshape belong to the third segment?
>
> (3) pointwise doing reshape + scale & bias

Yes.

wujingyue commented 3 months ago

> thunder will ensure reshape is moved to the front and end of the fusion

I missed this part. I was expecting nvFuser, not Thunder, to do this. It sounds like a specific code motion to work around an nvFuser limitation. Doing it in nvFuser keeps nvFuser standalone, without relying on a specific code pattern upstream.

liqiangxl commented 3 months ago

Thanks for the quick feedback. I'll check with other stakeholders and may schedule a short sync meeting.

naoyam commented 3 months ago

> It is rejected by reductionInterferingView(). This check divides the IDs of the reduction tv into groups according to whether they are iteration or reduction IDs. It then builds disjoint sets over all the IDs and checks whether any entry of the disjoint sets contains IDs that belong to two different groups.
>
> For example, consider a tv with [{i0}, {32}, **{i2/32}**, **{i3}**, **{i4}**], where bold marks the reduction dims.
>
>   • (1) They are divided into two groups: **{i2/32}, {i3}, {i4}** and {i0}, {32}.
>   • (2) **{i2/32}**, {32}, and {i2} are in the same entry of the disjoint sets (why?)
>   • (3) **{i2/32}** is in a different group from {32}.
>   • (4) reductionInterferingView() returns false and the fusion is rejected.
>
> I think we can safely skip these checks if the reshape only involves splits, since a split won't merge an iteration ID with a reduction ID, which is the root limitation of the current reduction scheduler. The fix is simple and avoids the complex checks. Another approach is to revise the logic of the disjoint-set-based checks, but that seems more complex than the current approach in #2437, so this option is not explored.

I see. I was only thinking about requiresForwardViewReplay, but it now makes sense.

Have you thought about using IdModel to rewrite reductionInterferingView? I think it would be relatively straightforward with an ID graph. All we need to check is whether any use of the ID groups of the reference tensor could merge multiple groups, which would be a simple graph traversal. It should naturally handle Merge as well.
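A rough sketch of that traversal, in plain Python rather than the IdModel API (the function name, the string IDs, and the use of individual IDs instead of ID groups are all simplifying assumptions): propagate each reference ID's kind forward through its uses and flag only a merge that combines iteration-derived and reduction-derived IDs. Unlike the disjoint-set check, this accepts the split-only first reshape and still rejects the second one:

```python
from collections import deque


def merges_across_groups(reference_ids, exprs):
    """reference_ids: {id_name: "iter" | "reduction"} for the reference tensor.
    exprs: list of (kind, inputs, outputs), kind being "split" or "merge".

    Walks forward from the reference IDs through their uses, tracking which
    kinds each reached ID derives from, and returns True as soon as a merge
    combines IDs derived from different kinds.
    """
    uses = {}
    for expr in exprs:
        for name in expr[1]:
            uses.setdefault(name, []).append(expr)

    derived_from = {name: {kind} for name, kind in reference_ids.items()}
    queue = deque(reference_ids)
    while queue:
        name = queue.popleft()
        for kind, inputs, outputs in uses.get(name, []):
            kinds = set()
            for inp in inputs:
                kinds |= derived_from.get(inp, set())
            if kind == "merge" and len(kinds) > 1:
                return True
            for out in outputs:
                if not kinds <= derived_from.get(out, set()):
                    derived_from[out] = derived_from.get(out, set()) | kinds
                    queue.append(out)
    return False


reference = {"i0": "iter", "32": "iter", "i2/32": "reduction",
             "i3": "reduction", "i4": "reduction"}
split = ("split", ("i2",), ("32", "i2/32"))           # first reshape
merge = ("merge", ("32", "i2/32"), ("i2_merged",))    # second reshape
print(merges_across_groups(reference, [split]))         # False
print(merges_across_groups(reference, [split, merge]))  # True
```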

jjsjann123 commented 3 months ago

Agreed that the proper approach is for nvFuser to handle graph-level optimization and reorder some trivial reshapes to allow better fusion. We can discuss the priority of that.

Meanwhile, the last reshape in Thunder might be easier to change, with some slightly ugly code: https://github.com/Lightning-AI/lightning-thunder/blob/14e6c9b67eb038ab28a192cd381bc183b77e8f81/thunder/torch/__init__.py#L4406-L4420