Open naoyam opened 1 day ago
@zasdfgbnm @jacobhinkle Please let me know if this makes sense. I have a WAR (#3454), which isn't ideal at all but works for RoPE.
It seems like this is specific to Resize. For example, if we resize by padding by `a` on the left and `-a` on the right, we could use the result in a binaryOp with the original input and wind up with the same situation.
Is this the only pattern we know of that displays this issue?
Maybe it's enough to just not exact map if the two groups already have a Resize ExprGroup between them? This would mean we'd have to have already processed all the Resize Exprs before building the exact graph. If we did that then the Exact graph would become finer than it is today; I'm not sure whether that would cause problems or not but I can imagine we might assume that Iteration input IDs to BinaryOp and TernaryOp are exact mapped in some places.
It seems like this is specific to Resize. For example, if we resize by padding by `a` on the left and `-a` on the right, we could use the result in a binaryOp with the original input and wind up with the same situation.
I think this indexing issue is specific to Resize, or more specifically the rotation. We map the rotated domain with its input, which of course has the same extent, but for the sake of indexing they should not be considered the same.
As discussed in #3072, reshape can also result in a cyclic graph, but I don't think it would cause an indexing problem like this.
Is this the only pattern we know of that displays this issue?
Yes.
Maybe it's enough to just not exact map if the two groups already have a Resize ExprGroup between them? This would mean we'd have to have already processed all the Resize Exprs before building the exact graph. If we did that then the Exact graph would become finer than it is today; I'm not sure whether that would cause problems or not but I can imagine we might assume that Iteration input IDs to BinaryOp and TernaryOp are exact mapped in some places.
Are you saying we shouldn't map `i0` and `i6`? Would that mean `i3` should also not be mapped with `i6`? I'm pretty sure that would cause a different problem with inlining and sync analysis since the exact graph is currently used as the basis of all the other graphs.
You're right, the simple approach is not sufficient.
Maybe we need another type of ExprGroup in this case that indicates multiple ValGroups are aligned in an expression such that they would normally be Exact mapped but cannot be, for example in this case, because they would introduce cycles. That would let us take different paths to the ValGroups of i3 and i0 for indexing. For determining if they have the same extents, we can derive an "Extents" graph by mapping all the input and output ValGroups of these "Align" ExprGroups. That Extents graph might be a substitute for Exact that would work for inlining and other analysis, but it could have cycles.
Fundamentally, this seems to indicate `i0` should not be mapped with `{i4, i5, i3, i6}`.
Why is it "`i0` should not be mapped with `{i4, i5, i3, i6}`", instead of the following?
`{i0, i6}` should not be mapped with `{i4, i5, i3}`
`{i0, i6, i3}` should not be mapped with `{i4, i5}`
When we are indexing t0 as a direct producer of t4, we do want i0 to be mapped to i6. Depending on the task we are doing, we may or may not want things to be mapped.
I think the fundamental problem is not whether i0 should be mapped with i6 or not. The fundamental problem is we should not traverse the exact graph for indexing. Indexing needs a real index graph, where two items are mapped if and only if they have the same index, as originally planned by Christian. This means, a tensor as a consumer, as a producer of different ops, because they have different indices, they are not mapped.
Oh, my comment was not that precise. I agree with you.
This means, a tensor as a consumer, as a producer of different ops, because they have different indices, they are not mapped.
Just so I understand this, you mean that this graph would contain (1+num_uses) copies of each ID in `tv->allIDs()` (or some subset if possible). Then we would do mapping by looking at each TV Expr and mapping the producer logical IDs' copy for that Expr with the consumer copy of the corresponding output TV's maybeRoot ID. This way the ValGraph doesn't have paths that propagate beyond one TV expression at a time, but maybe that's the point? It feels like that means we would be creating a single graph per producer/consumer TV pair.
Just so I understand this, you mean that this graph would contain (1+num_uses) copies of each ID in `tv->allIDs()` (or some subset if possible). Then we would do mapping by looking at each TV Expr and mapping the producer logical IDs' copy for that Expr with the consumer copy of the corresponding output TV's maybeRoot ID.
I think that's pretty much it.
It feels like that means we would be creating a single graph per producer/consumer TV pair.
Partially yes, but it is important to note that we can use a `ValGraph` to manage all these relations. We can use `(IterDomain, TV expr)` as the key for `ValGraph`.
For example, if we have
T2[b1, I1] = set(T0[b0, I0]) ca_pos(1)
T3[I2, I3] = add(T2[b1, I1], T1[I4, I5])
Then we will have groups like:
g0: {(b0, set), (b1, set), (b1, add)}
g1: {(I0, set), (I1, set)}
g2: {(I1, add), (I5, add), (I3, add)}
g3: {(I2, add), (I4, add)}
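The grouping above can be reproduced with a small union-find keyed on `(IterDomain, TV expr)` pairs. This is only a sketch in Python, not nvFuser code: `DisjointSets` and `build_index_groups` are hypothetical names, and the mappings are entered by hand from the example. In particular, merging `(b1, set)` with `(b1, add)` reflects my reading that `ca_pos(1)` makes `b1` share one index across both expressions.

```python
class DisjointSets:
    """Minimal union-find over hashable keys (hypothetical helper)."""

    def __init__(self):
        self.parent = {}

    def find(self, key):
        self.parent.setdefault(key, key)
        while self.parent[key] != key:
            # Path halving keeps the trees shallow.
            self.parent[key] = self.parent[self.parent[key]]
            key = self.parent[key]
        return key

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)

    def groups(self):
        by_root = {}
        for key in list(self.parent):
            by_root.setdefault(self.find(key), set()).add(key)
        return {frozenset(members) for members in by_root.values()}


def build_index_groups():
    ds = DisjointSets()
    # T2[b1, I1] = set(T0[b0, I0]): producer/consumer IDs map per position.
    ds.union(("b0", "set"), ("b1", "set"))
    ds.union(("I0", "set"), ("I1", "set"))
    # T3[I2, I3] = add(T2[b1, I1], T1[I4, I5]): b1 is a broadcast, so it is
    # not mapped with I2 or I4.
    ds.union(("I1", "add"), ("I3", "add"))
    ds.union(("I5", "add"), ("I3", "add"))
    ds.union(("I4", "add"), ("I2", "add"))
    # Assumption: ca_pos(1) inlines T2 at position 1, so b1 has the same
    # index as the consumer of `set` and as the producer of `add`.
    ds.union(("b1", "set"), ("b1", "add"))
    return ds.groups()


if __name__ == "__main__":
    for group in sorted(build_index_groups(), key=sorted):
        print(sorted(group))
```

Note that `(I1, set)` and `(I1, add)` land in different groups: the same IterDomain is keyed separately per expression, which is what keeps mappings from propagating beyond one TV expression.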
As discussed in #3072, a common pattern in RoPE results in a cyclic exact graph. For example, when a domain is split into two halves, rotated and concatenated, if the resulting domain is also used with the initial input domain, the initial domain and the final domain are mapped, and there are the resize expressions for the slice and concat, yielding a cycle in the exact graph.
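A toy model of how such a cycle can arise, as a hedged sketch rather than the actual nvFuser graph: ValGroups are union-find sets, resize expressions are directed edges between groups, and the domain names (`in`, `half`, `rotated`, `out`) are hypothetical stand-ins for the input, sliced, padded, and final domains.

```python
class DisjointSets:
    """Minimal union-find over hashable keys (hypothetical helper)."""

    def __init__(self):
        self.parent = {}

    def find(self, key):
        self.parent.setdefault(key, key)
        while self.parent[key] != key:
            self.parent[key] = self.parent[self.parent[key]]
            key = self.parent[key]
        return key

    def union(self, a, b):
        self.parent[self.find(a)] = self.find(b)


def has_cycle(sets, exprs):
    """Depth-first search for a cycle over group-level expression edges."""
    edges = {}
    for src, dst in exprs:
        edges.setdefault(sets.find(src), set()).add(sets.find(dst))
    in_progress, done = set(), set()

    def visit(group):
        in_progress.add(group)
        for nxt in edges.get(group, ()):
            if nxt in in_progress:
                return True
            if nxt not in done and visit(nxt):
                return True
        in_progress.discard(group)
        done.add(group)
        return False

    return any(g not in done and visit(g) for g in list(edges))


def build(with_residual):
    sets = DisjointSets()
    # Resize expressions: slice one half out, then pad it to the other side.
    exprs = [("in", "half"), ("half", "rotated")]
    # The concatenated output is mapped with the padded (rotated) domain.
    sets.union("rotated", "out")
    if with_residual:
        # Using the result together with the original input maps in and out.
        sets.union("in", "out")
    return sets, exprs


if __name__ == "__main__":
    print(has_cycle(*build(with_residual=False)))
    print(has_cycle(*build(with_residual=True)))
```

Without the residual use, the groups form a chain. Once `in` and `out` are merged, the resize edges lead from that group back into itself.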
This may also cause an invalid indexing path when scheduled in a certain way. For example, we could schedule a rotation pattern as illustrated below:
`t0` is the input tensor with only one iter domain, `i0`. The math should look like:
The scheduling illustrated here is consumer-based scheduling, meaning the loop domain of the final output, `t3`, is propagated back to `t1` and `t2`. They are scheduled by additional `resize` such that their loop domains match with the logical domain of `t3`. In this way, since all loop domains are exactly mapped, this fusion can be freely parallelized without any synchronization. This scheduling approach is what I'm aiming to have as a first version of the scheduler for RoPE.
However, if this fusion also has a residual path like shown below (i.e., `t4 = t0 + t3`), indexing `t0` as the producer of `t1` or `t2` doesn't work due to the exact mapping of `i0` and `i4` (and `i5`).
When indexing `t0` for `t1`, its loop domain is just `i4`, which is mapped with `i0`. In the new indexing method, this simply means the loop index of `i4` can be just used as is for indexing `t0`. Since they are grouped together, there's no indexing traversal involved.
However, this is clearly incorrect. For `t1`, we need the left half of `t0` to be placed at the right half of `t1`. So, the index math should look like `i - N/2`, where `i` is the loop index of `i4`. Similarly, for `t2`, it should be `i + N/2`.
Fundamentally, this seems to indicate `i0` should not be mapped with `{i4, i5, i3, i6}`. However, it's automatically done in our current formulation of the exact graph because of the `t4` expression, `t4 = t0 + t3`. While they do have the same extent, they may need to be considered unmapped for indexing.
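To make the index math concrete, here is a small Python sketch of the rotation on a plain list. It is an illustration under assumptions, not the generated kernel: I assume `t1` and `t2` are zero-padded halves and that `t3` combines them with an elementwise add, and the function names are hypothetical.

```python
def shifted_t1(i, n):
    # Left half of t0 lands in the right half of t1: read t0[i - N/2].
    return i - n // 2


def shifted_t2(i, n):
    # Right half of t0 lands in the left half of t2: read t0[i + N/2].
    return i + n // 2


def loop_index_as_is(i, n):
    # What falls out when the loop ID and i0 sit in one group: no traversal,
    # so the loop index is used directly.
    return i


def rotate(t0, index_t1, index_t2):
    n = len(t0)
    half = n // 2
    # t1 is only populated in its right half, t2 only in its left half;
    # the remaining elements are padding (assumed to be zero).
    t1 = [t0[index_t1(i, n)] if i >= half else 0 for i in range(n)]
    t2 = [t0[index_t2(i, n)] if i < half else 0 for i in range(n)]
    # Assumption: t3 is the elementwise sum of the two padded halves.
    return [a + b for a, b in zip(t1, t2)]


if __name__ == "__main__":
    t0 = [10, 20, 30, 40]
    print(rotate(t0, shifted_t1, shifted_t2))              # [30, 40, 10, 20]
    print(rotate(t0, loop_index_as_is, loop_index_as_is))  # [10, 20, 30, 40]
```

With the shifted indices the two halves swap as intended. Using the loop index as is silently reproduces `t0` instead of the rotation, which is the failure mode described above.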