Open gspschmid opened 2 months ago
cc @ezhulenev @frgossen @mattjj @nouiz
Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation?
Today, I would say this is expected behaviour in XLA. Async ops do not get special treatment yet but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.
The pass you'd want to look at is copy insertion. Especially around while and cond ops, it inserts copies and then tries to remove them, which it will only do if it can prove that the live ranges do not overlap.
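To make the "remove a copy only if the live ranges do not overlap" rule concrete, here is a toy sketch. It is not XLA's copy insertion pass: `LiveRange`, `overlaps` and `can_elide_copy` are hypothetical names, and the rule is reduced to a single interval check over schedule positions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LiveRange:
    """Half-open interval [start, end) of schedule positions (toy model)."""
    start: int
    end: int


def overlaps(a: LiveRange, b: LiveRange) -> bool:
    # Two half-open intervals overlap iff each one starts before the other ends.
    return a.start < b.end and b.start < a.end


def can_elide_copy(source: LiveRange, copy: LiveRange) -> bool:
    # Assumption for this sketch: the copy may be dropped (source and copy
    # share one buffer) only if the source is dead by the time the copy is live.
    return not overlaps(source, copy)


# Source dies exactly where the copy becomes live: the copy is redundant.
dead_source = can_elide_copy(LiveRange(0, 3), LiveRange(3, 6))
# Source is still read at positions 3 and 4: the copy has to stay.
live_source = can_elide_copy(LiveRange(0, 5), LiveRange(3, 6))
```

In this model, any extra use of the source that extends its live range past the copy is enough to keep the copy, which is the situation that breaks a start/done pass-through.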
Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is it that it can, but possibly incorrectly (given the chance of copies)?
Async ops do not get special treatment yet but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.
That makes me wonder whether there is actually any JAX-surface-level (or at least StableHLO) pattern that would guarantee a buffer being kept alive. In the presence of earlier program transformations and, say, host offloading, even relying on non-copying for straightline code seems like a tenuous assumption.
Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is it that it does, but possibly incorrectly (given the possibility of copies)?
It does not support that today but I'm working on it for loops.
that would guarantee a buffer being kept alive
I don't think there is at the moment, other than no loops/conds plus single-use chains. Is that what you mean by "straightline code"?
I think we need special treatment for async ops: we should be able to prove with dataflow analysis that an async start is actually consumed by its corresponding done operation, with nothing in between. I suspect the same issue will show up in pipeline partitioning and in reordering send/recv start and done.
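A minimal sketch of the kind of check this describes, assuming a toy IR rather than real HLO. `Op`, `users` and `start_flows_directly_to_done` are hypothetical names for illustration only. The check accepts an async start only if its single user is an async done.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    kind: str  # e.g. "parameter", "async-start", "async-done", "tuple"
    operands: list = field(default_factory=list)


def users(op, program):
    # All ops in the program that take `op` as an operand (identity comparison).
    return [o for o in program if any(x is op for x in o.operands)]


def start_flows_directly_to_done(start, program):
    # True iff `start` has exactly one user and that user is an async-done.
    us = users(start, program)
    return len(us) == 1 and us[0].kind == "async-done"


x = Op("x", "parameter")

# Single-use chain: start -> done.
s1 = Op("s1", "async-start", [x])
d1 = Op("d1", "async-done", [s1])
direct = start_flows_directly_to_done(s1, [x, s1, d1])

# Something sits between start and done (here a tuple), so the check fails.
s2 = Op("s2", "async-start", [x])
t2 = Op("t2", "tuple", [s2])
d2 = Op("d2", "async-done", [t2])
indirect = start_flows_directly_to_done(s2, [x, s2, t2, d2])
```

A real analysis would also have to look through loop and conditional boundaries, which this sketch does not attempt.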
Yes, this shows up in pipeline parallelism. That is the case I want to fix for loops.
Why focus on async ops? Why not focus on the aliasing behavior? What does the optimized HLO look like? That would show if and where the copy happens.
I think the problem is that today you can’t distinguish between mutable and immutable aliasing, and by default XLA assumes an in-place update when it sees aliasing buffers.
Async operations require a mechanism to keep buffers alive from op-start until the corresponding op-done. A natural way to express the intent to keep, say, the input buffer x to send-start(x) alive is to make send-start "pass the input buffer through" by instructing XLA to alias its first input with its first output, and then passing the aliased output to send-done. Async HLO Instructions in fact leverage this very idea. In pseudo code, x' = foo-start(x), with x' later consumed by foo-done(x').

Can we actually rely on this pattern working even in the presence of arbitrary aliasing operations being applied to x' before it eventually flows into foo-done?

I've been experimenting with custom calls that mimic this pattern (but are not known to be asynchronous to XLA) and encountered a situation where this assumption falls down. Consider the following example
which produces
Notably, in the case of f_bad we deal with two distinct buffers, i.e. the input is implicitly copied at some point, breaking our initial assumption.

Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation (https://github.com/openxla/xla/blob/6c49eab6ffeb5d1f22f9daac5378a5433335442a/xla/hlo/utils/hlo_live_range.cc#L215)?
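The failure mode can be illustrated with a toy model in plain Python. This is not the reproducer and does not call XLA or JAX: `Buffer`, `foo_start`, `implicit_copy` and `foo_done` are hypothetical stand-ins, and buffer identity is modelled with Python object identity.

```python
class Buffer:
    """Stand-in for a device buffer; identity models the underlying allocation."""

    def __init__(self, data):
        self.data = list(data)


def foo_start(x):
    # Pass-through: the output aliases the input, so the same object is returned.
    return x


def implicit_copy(x):
    # Models a compiler-inserted copy: same contents, different allocation.
    return Buffer(x.data)


def foo_done(x_prime, original):
    # Reports whether the buffer that reached done is the one handed to start.
    return x_prime is original


x = Buffer([1.0, 2.0, 3.0])

# Straightline pass-through: done sees the original buffer.
same_buffer = foo_done(foo_start(x), x)

# A copy sneaks in between start and done: done sees a different buffer.
copied = implicit_copy(foo_start(x))
different_buffer = foo_done(copied, x)
```

The contents still match after the copy, which is why the problem only becomes visible when something (such as an in-flight transfer) depends on the original allocation staying alive.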
Reproducer: https://gist.github.com/gspschmid/372bba804b48c4abbf5c94f19b2b32cd
HLO f_good (before optimizations)
HLO f_bad (before optimizations)