NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Invalid indexing path when resize is used with a residual path #3455

Open naoyam opened 1 day ago

naoyam commented 1 day ago

As discussed in #3072, a common pattern in RoPE results in a cyclic exact graph. For example, when a domain is split into two halves, rotated, and concatenated, and the resulting domain is also used together with the initial input domain, the initial and final domains are mapped. Since the resize expressions for the slice and concat also connect the two, this yields a cycle in the exact graph.

This may also cause an invalid indexing path when scheduled in a certain way. For example, we could schedule a rotation pattern as illustrated below:

[Image: scheduling of the rotation pattern]

t0 is the input tensor with only one iter domain, i0. The math should look like:

t1 = t0[:N/2]
t2 = t0[N/2:]
t3 = cat(t2, t1)

The scheduling illustrated here is consumer-based scheduling, meaning the loop domain of the final output, t3, is propagated back to t1 and t2. They are scheduled by additional resize such that their loop domains match with the logical domain of t3. In this way, since all loop domains are exactly mapped, this fusion can be freely parallelized without any synchronization. This scheduling approach is what I'm aiming to have as a first version of the scheduler for RoPE.

However, if this fusion also has a residual path as shown below (i.e., t4 = t0 + t3), indexing t0 as the producer of t1 or t2 doesn't work due to the exact mapping of i0 and i4 (and i5).

[Image: the same fusion with the residual path t4 = t0 + t3]

When indexing t0 for t1, its loop domain is just i4, which is mapped with i0. In the new indexing method, this simply means the loop index of i4 can be just used as is for indexing t0. Since they are grouped together, there's no indexing traversal involved.

However, this is clearly incorrect. For t1, we need the left half of t0 to be placed at the right half of t1. So, the index math should look like i - N/2, where i is the loop index of i4. Similarly, for t2, it should be i + N/2.
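The index math above can be checked with a small plain-Python sketch. This is only an illustration and not nvFuser code: `N = 8`, the list-based tensors, and the helper `producer_index` are assumptions made for the example.

```python
# Hypothetical sketch: N and the helper names are illustrative, not nvFuser code.
N = 8
t0 = list(range(N))

# Reference result: t3 = cat(t2, t1) with t1 = t0[:N/2] and t2 = t0[N/2:].
t1 = t0[: N // 2]
t2 = t0[N // 2 :]
t3_ref = t2 + t1


def producer_index(i, n):
    """Index into t0 for loop index i, where the loop domain matches t3's logical domain."""
    if i < n // 2:
        return i + n // 2  # left half of the loop range reads t2 = t0[N/2:]
    return i - n // 2  # right half of the loop range reads t1 = t0[:N/2]


# Shifted indexing reproduces the rotation.
t3_shifted = [t0[producer_index(i, N)] for i in range(N)]

# Using the loop index as is (what the exact mapping of i0 and i4 implies) does not.
t3_naive = [t0[i] for i in range(N)]

print(t3_shifted == t3_ref)  # True
print(t3_naive == t3_ref)    # False
```

The naive version simply copies t0, which is why reusing the loop index of i4 for t0 gives the wrong element.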

Fundamentally, this seems to indicate i0 should not be mapped with {i4, i5, i3, i6}. However, it's automatically done in our current formulation of the exact graph because of the t4 expression, t4 = t0 + t3. While they do have the same extent, they may need to be considered unmapped for indexing.
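To make the effect of the residual expression concrete, here is a minimal union-find sketch of how the mapping could come about. It is not the actual exact-graph builder. The assignment of names to tensors (i4 and i5 as the loop domains of t1 and t2, i3 for t3, i6 for t4) is inferred from the text, since the figure is not available, and the function `group_ids` is hypothetical.

```python
def group_ids(mappings, ids):
    """Union-find: return a dict from each ID to the frozenset of IDs grouped with it."""
    parent = {x: x for x in ids}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in mappings:
        parent[find(a)] = find(b)

    groups = {}
    for x in ids:
        groups.setdefault(find(x), set()).add(x)
    return {x: frozenset(groups[find(x)]) for x in ids}


ids = ["i0", "i3", "i4", "i5", "i6"]

# Assumption: consumer-based scheduling makes the loop domains of t1 and t2
# exactly mapped with the logical domain of t3.
schedule = [("i4", "i3"), ("i5", "i3")]

# The residual expression t4 = t0 + t3 maps the output with both inputs.
residual = [("i6", "i3"), ("i6", "i0")]

without_residual = group_ids(schedule, ids)
with_residual = group_ids(schedule + residual, ids)

print(sorted(without_residual["i0"]))  # ['i0']
print(sorted(with_residual["i0"]))     # ['i0', 'i3', 'i4', 'i5', 'i6']
```

Without the residual path, i0 stays in its own group and indexing has to traverse the resize expressions. With it, i0 lands in the same group as i4 and i5, so no traversal happens.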

naoyam commented 1 day ago

@zasdfgbnm @jacobhinkle Please let me know if this makes sense. I have a WAR (#3454), which isn't ideal at all but works for RoPE.

jacobhinkle commented 1 day ago

It seems like this is specific to Resize. For example if we resize by padding by a on the left and -a on the right, we could use the result in a binaryOp with the original input and wind up with the same situation.

Is this the only pattern we know of that displays this issue?

Maybe it's enough to just not exact map if the two groups already have a Resize ExprGroup between them? This would mean we'd have to have already processed all the Resize Exprs before building the exact graph. If we did that then the Exact graph would become finer than it is today; I'm not sure whether that would cause problems or not but I can imagine we might assume that Iteration input IDs to BinaryOp and TernaryOp are exact mapped in some places.

naoyam commented 1 day ago

> It seems like this is specific to Resize. For example if we resize by padding by a on the left and -a on the right, we could use the result in a binaryOp with the original input and wind up with the same situation.

I think this indexing issue is specific to Resize, or more specifically to the rotation. We map the rotated domain with its input. They of course have the same extent, but for the sake of indexing they should not be considered the same.

As discussed in #3072, reshape can also result in a cyclic graph, but I don't think it would cause an indexing problem like this.

> Is this the only pattern we know of that displays this issue?

Yes.

> Maybe it's enough to just not exact map if the two groups already have a Resize ExprGroup between them? This would mean we'd have to have already processed all the Resize Exprs before building the exact graph. If we did that then the Exact graph would become finer than it is today; I'm not sure whether that would cause problems or not but I can imagine we might assume that Iteration input IDs to BinaryOp and TernaryOp are exact mapped in some places.

Are you saying we shouldn't map i0 and i6? Would that mean i3 should also not be mapped with i6? I'm pretty sure that would cause a different problem with inlining and sync analysis since the exact graph is currently used as the basis of all the other graphs.

jacobhinkle commented 1 day ago

You're right, the simple approach is not sufficient.

Maybe we need another type of ExprGroup in this case that indicates multiple ValGroups are aligned in an expression such that they would normally be Exact mapped but cannot be, for example in this case, because they would introduce cycles. That would let us take different paths to the ValGroups of i3 and i0 for indexing. For determining if they have the same extents we can derive an "Extents" graph by mapping all the input and output ValGroups of these "Align" ExprGroups. That Extents graph might be a substitute for Exact that would work for inlining and other analysis, but it could have cycles.

zasdfgbnm commented 5 hours ago

> Fundamentally, this seems to indicate i0 should not be mapped with {i4, i5, i3, i6}.

Why is it "i0 should not be mapped with {i4, i5, i3, i6}", instead of the following?

When we are indexing t0 as a direct producer of t4, we do want i0 to be mapped to i6. Depending on the task we are doing, we may or may not want things to be mapped.

I think the fundamental problem is not whether i0 should be mapped with i6 or not. The fundamental problem is we should not traverse the exact graph for indexing. Indexing needs a real index graph, where two items are mapped if and only if they have the same index, as originally planned by Christian. This means, a tensor as a consumer, as a producer of different ops, because they have different indices, they are not mapped.

naoyam commented 4 hours ago

Oh, my comment was not that precise. I agree with you.

jacobhinkle commented 2 hours ago

> This means, a tensor as a consumer, as a producer of different ops, because they have different indices, they are not mapped.

Just so I understand this, you mean that this graph would contain (1 + num_uses) copies of each ID in tv->allIDs() (or some subset if possible). Then we would do mapping by looking at each TV Expr and mapping the producer logical IDs' copy for that Expr with the consumer copy of the corresponding output TV's maybeRoot ID. This way the ValGraph doesn't have paths that propagate beyond one TV expression at a time, but maybe that's the point? It feels like that means we would be creating a single graph per producer/consumer TV pair.

zasdfgbnm commented 1 hour ago

> Just so I understand this, you mean that this graph would contain (1 + num_uses) copies of each ID in tv->allIDs() (or some subset if possible). Then we would do mapping by looking at each TV Expr and mapping the producer logical IDs' copy for that Expr with the consumer copy of the corresponding output TV's maybeRoot ID.

I think that's pretty much it.

> It feels like that means we would be creating a single graph per producer/consumer TV pair.

Partially yes, but it is important to note that:

  1. We should still use a single ValGraph to manage all these relations. We can use (IterDomain, TV expr) as the key for ValGraph.
  2. We should map IterDomain across TV exprs based on inlining.

For example, if we have

T2[b1, I1] = set(T0[b0, I0]) ca_pos(1)
T3[I2, I3] = add(T2[b1, I1], T1[I4, I5])

Then we will have groups like:

g0: {(b0, set), (b1, set), (b1, add)}
g1: {(I0, set), (I1, set)}
g2: {(I1, add), (I5, add), (I3, add)}
g3: {(I2, add), (I4, add)}
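As a rough sketch of how such groups could be derived, the following plain-Python union-find keys each node by an (IterDomain, TV expr) pair and reproduces g0 through g3 for the example above. It is an illustration only, not nvFuser's ValGraph API. The function `build_groups` and the hand-written mapping list are assumptions, and the broadcast b1 is deliberately left unmapped with I2 and I4 inside the add.

```python
def build_groups(nodes, mappings):
    """Union-find over (IterDomain, expr) keys; returns a set of frozensets."""
    parent = {n: n for n in nodes}

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in mappings:
        parent[find(a)] = find(b)

    groups = {}
    for n in nodes:
        groups.setdefault(find(n), set()).add(n)
    return {frozenset(g) for g in groups.values()}


# One copy of each IterDomain per TV expr it takes part in.
nodes = [
    ("b0", "set"), ("I0", "set"), ("b1", "set"), ("I1", "set"),
    ("b1", "add"), ("I1", "add"), ("I4", "add"), ("I5", "add"),
    ("I2", "add"), ("I3", "add"),
]

mappings = [
    # Within T2 = set(T0): producer IDs map with the matching consumer IDs.
    (("b0", "set"), ("b1", "set")),
    (("I0", "set"), ("I1", "set")),
    # Within T3 = add(T2, T1): the inner IDs map; b1 is broadcast, so it is
    # not mapped with I2 or I4.
    (("I1", "add"), ("I3", "add")),
    (("I5", "add"), ("I3", "add")),
    (("I4", "add"), ("I2", "add")),
    # Across exprs, based on inlining: ca_pos(1) on T2 covers only b1.
    (("b1", "set"), ("b1", "add")),
]

groups = build_groups(nodes, mappings)
for g in sorted(groups, key=lambda s: sorted(s)):
    print(sorted(g))
```

Because I1 is outside the inlined position, its copy in the set expr and its copy in the add expr stay in separate groups, which is what allows the two uses to carry different indices.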