openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Buffer aliasing behavior in the presence of control flow #16793

Status: Open. gspschmid opened this issue 2 months ago

gspschmid commented 2 months ago

Async operations require a mechanism to keep buffers alive from op-start until the corresponding op-done. A natural way to express the intent to keep, say, the input buffer x to send-start(x) alive is to make send-start "pass the input buffer through" by instructing XLA to alias its first input with its first output, and then passing the aliased output to send-done. Async HLO Instructions in fact leverages this very idea. In pseudocode:

x' = send-start(x)  # send-start aliases input 0 with output 0
# other ops
y = send-done(x')  # original x should be alive until now, allowing send to keep using the buffer
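As an illustration of why the pass-through keeps x alive, here is a toy live-range computation over a straight-line program. It is a minimal sketch under stated assumptions: `Instr`, `alias_input`, and `buffer_live_ranges` are hypothetical names (not XLA's HloLiveRange API), and each aliasing instruction is assumed to reuse exactly one input's buffer for its output.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class Instr:
    name: str                                        # value defined by this instruction
    inputs: List[str] = field(default_factory=list)  # values it reads
    alias_input: Optional[int] = None                # input whose buffer the output reuses (hypothetical)


def buffer_live_ranges(program: List[Instr]) -> Dict[str, Tuple[int, int]]:
    """Map each buffer to its (definition, last use) instruction indices.

    An aliased output reuses its input's buffer, so every later use of the
    aliased value extends the live range of the original buffer."""
    buffer_of: Dict[str, str] = {}     # value name -> buffer id
    ranges: Dict[str, List[int]] = {}  # buffer id -> [start, end]
    for i, ins in enumerate(program):
        for v in ins.inputs:
            ranges[buffer_of[v]][1] = i  # a use extends the buffer's live range
        if ins.alias_input is not None:
            buf = buffer_of[ins.inputs[ins.alias_input]]  # pass the input buffer through
        else:
            buf = ins.name                                # fresh buffer named after the value
            ranges[buf] = [i, i]
        buffer_of[ins.name] = buf
    return {b: (r[0], r[1]) for b, r in ranges.items()}


program = [
    Instr("x"),                         # parameter
    Instr("x'", ["x"], alias_input=0),  # send-start aliases input 0 with output 0
    Instr("z"),                         # other ops
    Instr("y", ["x'"]),                 # send-done
]
print(buffer_live_ranges(program))  # {'x': (0, 3), 'z': (2, 2), 'y': (3, 3)}
```

Because x' shares x's buffer, the use by send-done at index 3 stretches that buffer's live range to (0, 3). If a copy were inserted anywhere along the chain, x' would get a buffer of its own and x's range could end at the copy instead.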

Can we actually rely on this pattern when arbitrary aliasing operations are applied to x' before it eventually flows into the matching -done op (send-done above)?

I've been experimenting with custom calls that mimic this pattern (but are not known to be asynchronous to XLA) and encountered a situation where this assumption breaks down. Consider the following example:

import jax
import jax.numpy as jnp

def show(x):
  # a primitive that aliases x and prints the address of x
  (...)

@jax.jit
def f_good(x):
  x = show(x)
  y = jax.lax.cond(x != 999, (lambda: x), (lambda: x))
  return show(y)

@jax.jit
def f_bad(x):
  x = show(x)
  y = jax.lax.cond(x != 999, (lambda: x), (lambda: jnp.zeros_like(x)))
  return show(y)

def example():
  print('f_good:')
  f_good(123)

  print('\nf_bad:')
  f_bad(123)

which produces

f_good:
SHOW in=0x7f294a000100
SHOW in=0x7f294a000100

f_bad:
SHOW in=0x7f294a000280
SHOW in=0x7f294a000100

Notably, in f_bad we are dealing with two distinct buffers, i.e., the input is implicitly copied at some point, which breaks our initial assumption.

Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation (https://github.com/openxla/xla/blob/6c49eab6ffeb5d1f22f9daac5378a5433335442a/xla/hlo/utils/hlo_live_range.cc#L215)?

Reproducer: https://gist.github.com/gspschmid/372bba804b48c4abbf5c94f19b2b32cd
Attached: HLO f_good (before optimizations), HLO f_bad (before optimizations)
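One simplified way to see why f_bad can end up with two buffers while f_good does not: all branches of a conditional produce their result in a single output buffer. The sketch below is not XLA's copy-insertion algorithm and does not reproduce the exact addresses above; it assumes that a branch returning a buffer owned by the enclosing computation (such as the operand x) must copy it when the branches disagree. `plan_conditional` and the `"cond.out"` buffer name are hypothetical.

```python
from typing import List, Tuple

# (buffer id, defined_inside_branch) for the value each branch returns
BranchRoot = Tuple[str, bool]


def plan_conditional(branch_roots: List[BranchRoot]) -> Tuple[str, List[int]]:
    """Toy model: all branches of a conditional write one shared output buffer.

    Returns (output buffer, indices of branches whose root must be copied)."""
    outside = {buf for buf, inside in branch_roots if not inside}
    if len(outside) == 1 and not any(inside for _, inside in branch_roots):
        # Every branch passes through the same outer buffer: reuse it, no copy.
        return next(iter(outside)), []
    # Branches disagree: allocate a fresh output buffer. Values computed inside
    # a branch can be written there directly; buffers owned by the enclosing
    # computation have to be copied in.
    copies = [i for i, (_, inside) in enumerate(branch_roots) if not inside]
    return "cond.out", copies


print(plan_conditional([("x", False), ("x", False)]))     # like f_good: ('x', [])
print(plan_conditional([("x", False), ("zeros", True)]))  # like f_bad: ('cond.out', [0])
```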

gspschmid commented 2 months ago

cc @ezhulenev @frgossen @mattjj @nouiz

frgossen commented 2 months ago

> Is this expected behavior or a bug in bufferization/copying? I haven't investigated, but perhaps XLA-native async ops rely on special treatment during live range computation?

Today, I would say this is expected behaviour in XLA. Async ops do not get special treatment yet, but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.

The pass you'd want to look at is copy insertion. Especially around while and conditional ops, it inserts copies and then tries to remove them, which it only does if it can prove that the live ranges do not overlap.
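The copy-removal condition described here can be sketched as a live-range interference test. This is a simplified model, not the actual copy-insertion implementation: live ranges are assumed to be inclusive (definition, last use) indices in one sequential instruction order, and the copy instruction itself is allowed to end one range and start the other. The function names are hypothetical.

```python
from typing import Tuple

LiveRange = Tuple[int, int]  # (defining instruction index, last-use index), inclusive


def live_ranges_interfere(a: LiveRange, b: LiveRange) -> bool:
    """True if the ranges overlap by more than the single point where one
    value ends and the other begins (the copy instruction itself)."""
    (_, a_end), (b_start, _) = sorted([a, b])
    return b_start < a_end


def can_remove_copy(src: LiveRange, dst: LiveRange) -> bool:
    # Removing dst = copy(src) makes src and dst share one buffer,
    # which is only safe if their live ranges do not interfere.
    return not live_ranges_interfere(src, dst)


print(can_remove_copy((0, 3), (3, 5)))  # True: src dies at the copy
print(can_remove_copy((0, 6), (3, 5)))  # False: src is still read after the copy
```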

gspschmid commented 2 months ago

Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is it that it can, but possibly incorrectly so (given the chance of copies)?

> Async ops do not get special treatment yet but they do require that the buffer passed from async start to async done does not have any other uses. That way, you can rely on there not being a copy.

That makes me wonder whether there is actually any JAX-surface-level (or at least StableHLO) pattern that would guarantee a buffer being kept alive. In the presence of earlier program transformations and, say, host offloading, even relying on the absence of copies in straightline code seems like a tenuous assumption.

frgossen commented 2 months ago

> Does that mean that XLA cannot take an async op in a conditional branch/loop and move the async-done outside the branch/loop body? Or is it that it does, but possibly incorrectly so (given the possibility of copies)?

It does not support that today, but I'm working on it for loops.

> that would guarantee a buffer being kept alive

I don't think there is at the moment, other than avoiding loops/conds and using single-use chains. Is that what you mean by "straightline code"?

ezhulenev commented 2 months ago

I think we need special treatment for async ops: we should be able to prove with dataflow analysis that the async start is actually consumed by the corresponding done operation, with nothing in between. I suspect the same issue will show up in pipeline partitioning and when reordering send/recv start and done.
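A rough sketch of such a dataflow check, combined with the single-use-chain condition mentioned earlier: walk forward from the async start's result and require that every value along the way has exactly one user until the matching done is reached. The graph encoding (a dict from each value to the instructions that read it, where each instruction is assumed to define one value with the same name) and `start_reaches_done` are hypothetical. A stricter variant that only accepts a direct start-to-done edge would match "nothing in between" literally.

```python
from typing import Dict, List


def start_reaches_done(users: Dict[str, List[str]], start: str, done: str) -> bool:
    """Return True if the value defined by `start` flows into `done` through a
    chain in which every value has exactly one user."""
    current, seen = start, set()
    while current != done:
        if current in seen:        # guard against cycles in malformed input
            return False
        seen.add(current)
        consumers = users.get(current, [])
        if len(consumers) != 1:    # extra use, or no use: cannot rule out a copy
            return False
        current = consumers[0]
    return True


direct = {"send-start": ["send-done"]}
chained = {"send-start": ["show"], "show": ["send-done"]}
forked = {"send-start": ["show", "cond"], "show": ["send-done"]}

print(start_reaches_done(direct, "send-start", "send-done"))   # True
print(start_reaches_done(chained, "send-start", "send-done"))  # True
print(start_reaches_done(forked, "send-start", "send-done"))   # False
```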

frgossen commented 2 months ago

> I think we need special treatment for async ops: we should be able to prove with dataflow analysis that the async start is actually consumed by the corresponding done operation, with nothing in between. I suspect the same issue will show up in pipeline partitioning and when reordering send/recv start and done.

Yes, this shows up in pipeline parallelism. That is the case I want to fix it for in loops.

nouiz commented 2 months ago

Why focus on async ops? Why not focus on the aliasing behavior? What does the optimized HLO look like? It would show whether and where the copy happens.
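For reference, JAX's ahead-of-time API can print the post-optimization HLO: `jax.jit(f).lower(...).compile().as_text()`. The sketch below assumes the `show` custom call is dropped (its implementation is elided in the issue), so it only inspects the control-flow part of f_bad and may not show the same copies as the reproducer. Matching `" copy("` in the text is a rough heuristic, not an official way to enumerate copies.

```python
import jax
import jax.numpy as jnp


# Stand-in for f_bad without the elided `show` custom call (an assumption).
def cond_only(x):
    return jax.lax.cond(x != 999, lambda: x, lambda: jnp.zeros_like(x))


x = jnp.int32(123)
lowered = jax.jit(cond_only).lower(x)       # module before XLA optimizations
compiled = lowered.compile()                # runs the XLA pipeline, including copy insertion
optimized_hlo = compiled.as_text() or ""    # as_text() may return None on some backends

# Heuristic: list HLO instructions whose opcode is `copy`.
copies = [line.strip() for line in optimized_hlo.splitlines() if " copy(" in line]
print("\n".join(copies) or "no copy instructions found")
```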

ezhulenev commented 2 months ago

I think the problem is that today you can’t distinguish between mutable and immutable aliasing, and by default XLA assumes an in-place update when it sees aliasing buffers.
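The distinction drawn here can be written as a one-line rule. `AliasKind` and `needs_copy` are hypothetical: per the comment above, XLA today has no such distinction and treats aliasing as an in-place (mutable) update, which in this toy corresponds to always passing `AliasKind.MUTABLE`.

```python
from enum import Enum


class AliasKind(Enum):
    MUTABLE = "mutable"      # the op may update the shared buffer in place
    IMMUTABLE = "immutable"  # the op only passes the buffer through unchanged


def needs_copy(kind: AliasKind, input_read_after_op: bool) -> bool:
    """Toy rule: give the op a private copy of its input only if it may write
    the shared buffer while some other user still needs the old contents."""
    return kind is AliasKind.MUTABLE and input_read_after_op


# Today's behavior as described above: every alias is treated as MUTABLE.
print(needs_copy(AliasKind.MUTABLE, input_read_after_op=True))    # True
# A hypothetical immutable pass-through alias would not force the copy.
print(needs_copy(AliasKind.IMMUTABLE, input_read_after_op=True))  # False
```

Under this model, a pass-through such as send-start(x) declared immutable would keep sharing x's buffer even when x is still read elsewhere.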