Open liqiangxl opened 3 months ago
I wonder why (1) and (2) are not fused. I think an easy solution for the second reshape would be to move it to the end of the fusion and turn it into a metadata operation.

They are not fused due to the `reductionInterferingView` check. If I disable this check and remove the last reshape in group norm (no weight & bias), nvFuser uses the Inner Persistent scheduler but fails with the error message `Merging IterDomains requires that their iteration types match. Outer: iS133{32}, Inner: rS17{i2}`.

We can also move the 1st reshape to the beginning of the fusion, so it is just a no-op. What do you think?
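For context on why a boundary reshape is essentially free: in PyTorch, reshaping a contiguous tensor is a metadata-only view, so no data moves. A small sketch using the shapes from this issue:

```python
import torch

# Reshape of a contiguous tensor is a view: no data is copied,
# only the size/stride metadata changes.
x = torch.randn(2, 128, 16, 16)
y = x.reshape(2, 32, 4, 16, 16)

# Same underlying storage, so moving this reshape to the fusion
# boundary makes it a free (no-op) operation for the scheduler.
assert y.data_ptr() == x.data_ptr()
```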
There's only one normalization in (1) and (2), so I assume `x1.sum` is used as the reference for the segment. Which tensor does this conflict come from?
Looks like thunder did some optimizations of the captured graph. Specifically, it does a `reshape` of the input before the cast to fp32. After the computation, it adds another reshape before casting back to fp16. The 2nd reshape caused the error.
```python
def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  t0 = torch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t0 = ltorch.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
      # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
  [t16, t5, t9] = nvFusion0(t0)
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t0, t5, t9), ())
```
Thanks for the suggestion @jjsjann123 and @kevinstephano. nvFuser gets both reshapes if `thunder.jit(torch_group_norm, nv_enable_bookend=False)` is used:
```python
def augmented_forward_fn(x):
  # x: "cuda:0 f16[2, 128, 16, 16]"
  [t16, t5, t9] = nvFusion0(x)
    # t0 = prims.reshape(x, (2, 32, 4, 16, 16))  # t0: "cuda:0 f16[2, 32, 4, 16, 16]"
    # t1 = prims.convert_element_type(t0, dtypes.float32)  # t1: "cuda:0 f32[2, 32, 4, 16, 16]"
    # (t4, t5) = prims.var_mean(t1, (2, 3, 4), correction=0)
    # t6 = prims.broadcast_in_dim(t4, [2, 32, 1, 1, 1], [0, 1])  # t6: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t7 = prims.broadcast_in_dim(t5, [2, 32, 1, 1, 1], [0, 1])  # t7: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t8 = prims.add(t6, 1e-05)  # t8: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[2, 32, 1, 1, 1]"
    # t10 = prims.broadcast_in_dim(t7, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t10: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t12 = prims.sub(t1, t10)  # t12: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t13 = prims.broadcast_in_dim(t9, (2, 32, 4, 16, 16), (0, 1, 2, 3, 4))  # t13: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t14 = prims.mul(t12, t13)  # t14: "cuda:0 f32[2, 32, 4, 16, 16]"
    # t15 = prims.reshape(t14, (2, 128, 16, 16))  # t15: "cuda:0 f32[2, 128, 16, 16]"
    # t16 = prims.convert_element_type(t15, dtypes.float16)  # t16: "cuda:0 f16[2, 128, 16, 16]"
  return {'output': t16, 'flat_args': [x], 'flat_output': (t16,)}, ((t5, t9, x), (0,))
```
If we change these two reshapes to a `reshape of input` and a `reshape before output`, and revise `MarkAliasesPreparePass`, nvFuser can go from 3 kernels to two no-ops and a normalization kernel. I added a draft PR (#2405) with three test cases:

- `GroupNormOriginal`: segmented into 3 kernels
- `GroupNormReshapeMovedToInputOutputNoWeightBias`: one kernel
- `GroupNormReshapeMovedToInputOutput`: one kernel

So the proposed plan is a 2-step approach:

1. thunder will ensure reshape is moved to the front and end of the fusion, so the fusion only has a `reshape of input` or a `reshape before output`.
2. Revise `MarkAliasesPreparePass` to get one kernel if the fusion only has a `reshape of input` and a `reshape before output`. PR #2405.

Hi @wujingyue, do you have any comments about this plan? Thanks!
Adding the original related issue: https://github.com/Lightning-AI/lightning-thunder/issues/468
Unfortunately, the current alias analysis wasn't built for tracking aliases involving intermediates. When I wrote it, I didn't see a strong case for dealing with intermediates and expected intermediates to be fused into a kernel and become cheap index calculation. However, in this case, it appears that the two reshapes have caused the normalization scheduler to bail out...
That being said, is the normalization scheduler supposed to handle the two reshapes? That feels like the right solution to me. However, if it would take a long time, I would totally understand the need for a quicker workaround.
Thanks for the write-up, @liqiangxl! The problem looks quite clear to me even though I was unsure about the solution.
> That being said, is the normalization scheduler supposed to handle the two reshapes? That feels like the right solution to me. However, if it would take a long time, I would totally understand the need for a quicker workaround.

It should be able to handle the first reshape (which only includes a split of an ID), but the second reshape (which is a merge of an iter ID and a reduction ID) needs a lot of work. So another option is to extend the current reduction/normalization scheduler to handle some kinds of reshapes, so the pre-segmentation optimization pass only needs to process specific types of reshape, e.g. a reshape that includes a merge of an iter ID and a redu ID.
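For illustration, a rough way to tell the two kinds of reshape apart from shapes alone is to check whether every input dimension factors into a contiguous run of output dimensions. This is a hypothetical helper, not nvFuser code:

```python
# Hypothetical helper: decide whether a reshape from `src` to `dst`
# can be expressed using only splits, i.e. each src dim maps to one
# or more contiguous dst dims whose product equals it. If not, the
# reshape needs at least one merge.
def split_only(src, dst):
    i = 0
    for d in src:
        prod = 1
        while i < len(dst) and prod < d:
            prod *= dst[i]
            i += 1
        if prod != d:
            return False  # needs a merge (or an uneven factorization)
    return i == len(dst)

# The 1st reshape in this issue only splits C=128 into 32 x 4:
print(split_only((2, 128, 16, 16), (2, 32, 4, 16, 16)))  # → True
# The 2nd reshape merges 32 (iter) with 4 (tied to reduction IDs):
print(split_only((2, 32, 4, 16, 16), (2, 128, 16, 16)))  # → False
```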
@jjsjann123, @wujingyue, and @naoyam Thanks for the helpful discussions. Here is a summary of the approach:

- Since the 1st reshape only involves a `split` of domains, it won't interfere with reduction where iter and redu domains are merged separately. #2437 allows the reduction/normalization schedulers to take the 1st reshape.
- Since the 2nd reshape involves a `merge` of `iter` and `redu` domains, #2405 peels off the 2nd reshape into a no-op.
Trying to catch up on what's going on here, but I'm still confused why these two are not fused:

(1) pointwise doing cast + reshape
(2) normalization

Doesn't the second reshape belong to the third segment?

(3) pointwise doing reshape + scale & bias
> Here is a summary of the approach
LGTM. I'll defer to others whether we should fix the scheduler(s) to accept both types of reshapes. That sounds like the right fix to me even though it may take longer.
> Trying to catch up what's going on here, but still confused why these two are not fused: (1) pointwise doing cast + reshape (2) normalization
Rejected by `reductionInterferingView()`. It divides the IDs of the reduction tv into multiple groups based on iter or reduction, then generates a disjoint set over all the IDs and checks whether any IDs in the same entry of the disjoint set belong to two different groups.

For example, consider a tv with [{i0}, {32}, **{i2/32}**, **{i3}**, **{i4}**], where bold represents reduction dims:

1. The IDs are grouped into 2 groups: **{i2/32}**, **{i3}**, **{i4}** and {i0}, {32}.
2. **{i2/32}**, {32}, and {i2} are in the same entry of the disjoint sets (why? because they are all related through the reshape split of {i2} into {32} and {i2/32}).
3. **{i2/32}** is in a different group from {32}.
4. `reductionInterferingView()` returns false and the fusion is rejected.

I think we can safely skip these checks if the reshape only involves a `split`, since a split won't cause a merge of an `iter` ID and a `redu` ID, which is the root limitation of the current reduction scheduler. The fix is simple and avoids the complex checks.

Another approach is to revise the logic in the disjoint-set-based checks. It seems more complex than the current approach in #2437, so this option is not explored.
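The check described in steps (1)–(4) can be modeled with a toy union-find. This is a simplified Python sketch, not nvFuser's implementation; the IDs and groups follow the example above:

```python
# Simplified model of a reductionInterferingView()-style check:
# IDs are grouped by iter vs. reduction, IDs related through reshape
# exprs share a disjoint-set entry, and the fusion is rejected if one
# entry spans two different groups.

class DisjointSet:
    def __init__(self):
        self.parent = {}
    def find(self, x):
        self.parent.setdefault(x, x)
        while self.parent[x] != x:
            x = self.parent[x]
        return x
    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)

# Example tv: [{i0}, {32}, {i2/32}, {i3}, {i4}];
# {i2/32}, {i3}, {i4} are reduction dims.
groups = {"i0": "iter", "32": "iter",
          "i2/32": "redu", "i3": "redu", "i4": "redu"}

ds = DisjointSet()
# The reshape split i2 -> {32} x {i2/32} links all three IDs.
ds.union("i2", "32")
ds.union("i2", "i2/32")

def interferes():
    seen = {}  # disjoint-set representative -> group
    for id_, grp in groups.items():
        rep = ds.find(id_)
        if seen.setdefault(rep, grp) != grp:
            return True  # one entry spans iter and redu groups
    return False

# {32} (iter) and {i2/32} (redu) share an entry, so the check fails:
print(interferes())  # → True
```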
> Doesn't the second reshape belong to the third segment? (3) pointwise doing reshape + scale & bias
Yes.
> thunder will ensure reshape is moved to the front and end of the fusion

I missed this part. I was expecting nvFuser, not Thunder, to do this. It sounds like a specific code motion to work around an nvFuser limitation. Doing this in nvFuser makes nvFuser standalone, not relying on a specific code pattern upstream.
Thanks for the quick feedback. I'll check with other stakeholders and may schedule a short sync meeting.
> Rejected by `reductionInterferingView()` […] I think we can safely skip these checks if the reshape only involves a `split` […]
I see. I was only thinking about `requiresForwardViewReplay`, but it now makes sense.

Have you thought about using IdModel to rewrite `reductionInterferingView`? I think it'd be relatively straightforward with an ID graph. All we need to see is whether any use of the ID groups of the reference tensor could merge multiple groups, which would be a simple graph traversal. It should naturally handle Merge as well.
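A rough sketch of that traversal idea, with hypothetical data structures rather than the actual IdModel API: starting from the reference tensor's ID groups, walk each use expression and flag any merge whose inputs come from different groups.

```python
# Hypothetical sketch of the suggested ID-graph traversal:
# reject if any use expression merges IDs from different groups
# of the reference tensor. Each expr is (op, inputs, outputs).
exprs = [
    ("split", ["i2"], ["32", "i2/32"]),
    ("merge", ["32", "i2/32"], ["i2'"]),  # the 2nd reshape in this issue
]

# Reference tensor groups: iter vs. reduction IDs.
group_of = {"i0": "iter", "32": "iter",
            "i2/32": "redu", "i3": "redu", "i4": "redu"}

def merges_across_groups(exprs, group_of):
    g = dict(group_of)
    changed = True
    while changed:  # propagate groups until a fixed point
        changed = False
        for op, ins, outs in exprs:
            in_groups = {g[i] for i in ins if i in g}
            if op == "merge" and len(in_groups) > 1:
                return True  # a use merges iter and redu groups
            # a single-group expr (e.g. split) keeps the group
            if len(in_groups) == 1:
                for o in outs:
                    if o not in g:
                        g[o] = next(iter(in_groups))
                        changed = True
    return False

print(merges_across_groups(exprs, group_of))  # → True
```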
Agree that the proper approach is for nvFuser to handle graph-level optimization and reorder trivial reshapes to allow better fusion. We can discuss the priority on that.

Meanwhile, the last reshape in thunder might be an easier thing to change with some slightly ugly code: https://github.com/Lightning-AI/lightning-thunder/blob/14e6c9b67eb038ab28a192cd381bc183b77e8f81/thunder/torch/__init__.py#L4406-L4420
Group norm is calculated as:
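In code, the computation looks like this (a sketch using the shapes from this issue, no weight & bias; this mirrors the trace above rather than any nvFuser internals):

```python
import torch
import torch.nn.functional as F

# Group norm over N=2, C=128 channels in G=32 groups (no weight/bias):
# reshape to (N, G, C//G, H, W), normalize over the last three dims,
# then reshape back to (N, C, H, W).
N, C, H, W, G = 2, 128, 16, 16, 32
x = torch.randn(N, C, H, W, dtype=torch.float16)

t0 = x.reshape(N, G, C // G, H, W)   # 1st reshape (split of C)
t1 = t0.float()                      # cast to fp32
var, mean = torch.var_mean(t1, dim=(2, 3, 4), correction=0, keepdim=True)
t14 = (t1 - mean) * torch.rsqrt(var + 1e-5)
out = t14.reshape(N, C, H, W).half()  # 2nd reshape (merge) + cast back

# Matches the eager reference.
ref = F.group_norm(x.float(), G, eps=1e-5).half()
assert torch.allclose(out.float(), ref.float(), atol=1e-2)
```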
Due to the two reshapes, the normalization scheduler rejects the unsegmented fusion, and it is then segmented into three sub-fusions: (1) pointwise doing `cast + reshape`, (2) normalization, (3) pointwise doing `reshape + scale & bias`.

Reproduce (modified from the apex group norm implementation and the apex group norm test):