google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Guard against unintentional transfers to CPU #23215

Open kxs-dhenshall opened 3 weeks ago

kxs-dhenshall commented 3 weeks ago

To work around long compile times and hidden recompilations, JIT-compiled logic can be broken into a sequence of smaller functions.

One implication is that the control flow has to run on the host, since using the jax.lax control flow functions results in compilation of the target condition/body functions, which effectively just compiles the original function again. For example, if you were to break a function fn into two functions fn1 and fn2, your code might look something like this, with the loop running on the host:

state = compute_initial_state()

for i in range(n):
    state = fn1(state)
    state = fn2(state)

print(state.result)

Making this work efficiently requires a few extra considerations:

  1. Avoid GPU-to-CPU transfers between calls to the JAX-compiled functions
  2. Avoid re-allocations as much as possible through aggressive donation of arguments
  3. Avoid hidden recompilation

This works but is fragile, and it can be difficult to protect against regressions without burdensome performance testing. One thing that works well is that (3.) can be dealt with easily through AOT compilation, which throws an exception whenever a regression mistakenly changes the type of an argument passed to a compiled function (e.g. an array in state is initialized to [1] with a data type of int32 when it should have been initialized as [1.] with a data type of float32).
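A minimal sketch of the AOT check described above, assuming a toy step function `fn1` and a one-field state dict (both hypothetical stand-ins, not the actual code). Compiling ahead of time for a float32 input means a later call with an int32 input raises instead of silently recompiling:

```python
import jax
import jax.numpy as jnp


def fn1(state):
    # Hypothetical stand-in for one of the smaller jitted steps.
    return {"x": state["x"] * 2.0}


good_state = {"x": jnp.array([1.0])}  # float32, as intended
bad_state = {"x": jnp.array([1])}     # int32, the accidental regression

# Lower and compile once for the intended argument types.
compiled_fn1 = jax.jit(fn1).lower(good_state).compile()

result = compiled_fn1(good_state)

try:
    compiled_fn1(bad_state)
    mismatch_detected = False
except TypeError:
    # An AOT-compiled function rejects arguments whose types differ from
    # the ones it was compiled for, rather than recompiling.
    mismatch_detected = True
```

Running the whole loop through AOT-compiled functions like this turns a dtype or shape drift in the state into an immediate error at the call site.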

For use cases like this, I think that it would be useful to have mechanisms that would protect against the following two cases:

  1. A GPU-to-CPU transfer mistakenly took place within the critical loop, because host-side logic accessed a GPU variable.
  2. An argument that should have been donated was not donated (e.g. this could be introduced by changing a method signature without updating donate_argnames)

One approach would be to introduce new context managers that prevent these from happening inside a critical section of code. For example, for (1.) it could look something like the following (please ignore the names, I did not put much thought into them):

state = compute_initial_state()
state = jax.device_put(state)

with jax.disable_transfer_to_cpu():
    # Within this block, an exception will be thrown on any GPU-to-CPU transfer
    compiled_fn1 = compile(fn1, state)
    compiled_fn2 = compile(fn2, state)

    # This would throw an exception, since printing state.result would trigger
    # a transfer from GPU to CPU
    # print(state.result)

# This would not trigger an exception, since it is outside the protected block above
print(state.result)

Is this something that had been thought about or considered? I am not sure if this is an issue that other people have faced or not, and figured it was worth bringing up.

yashk2810 commented 3 weeks ago

For 1) you can use https://jax.readthedocs.io/en/latest/transfer_guard.html
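As a rough sketch of how the transfer guard could wrap the critical loop (the step function `fn1` and the loop length are hypothetical; whether the implicit case actually raises may depend on the backend, which I have not verified for CPU):

```python
import jax
import jax.numpy as jnp


@jax.jit
def fn1(x):
    # Hypothetical step function.
    return x + 1.0


# Create the state on the default device before entering the guarded region.
state = jnp.zeros(4)

with jax.transfer_guard_device_to_host("disallow"):
    for _ in range(3):
        state = fn1(state)  # stays on device, so the guard is not triggered

    # Explicit transfers remain allowed under "disallow".
    explicit_copy = jax.device_get(state)

    try:
        print(state)  # implicit device-to-host transfer
        implicit_blocked = False
    except Exception:
        # On accelerator backends the guard is expected to raise here.
        implicit_blocked = True
```

The "log" level can be used in place of "disallow" to report implicit transfers without raising.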

For 2) You can add an arr.is_deleted() check to make sure donation was successful. But note that in the future, we will delete the input array regardless of whether donation was successful or not (which will help you?)

kxs-dhenshall commented 3 weeks ago

Thanks for the quick response. For (1), transfer_guard looks perfect.

For (2), I may do that, but I don't really want to keep a bunch of checks around. I have my own wrapper around the compiled function anyway and can add the is_deleted check automatically, so I may just do that. In general, though, that isn't the tool I was hoping for. I am going to dig around to see what kind of allocation stats are available. If they are good enough, I can use them to make sure that the memory allocated at the end of each iteration of the loop is not more than the memory allocated at the start.
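A sketch of what such an allocation check might look like, assuming `Device.memory_stats()` reports a `bytes_in_use` entry on the backend in use (it can return None on backends without allocator stats, so the helper tolerates that). The step function `fn1` and the helper name are hypothetical:

```python
import jax
import jax.numpy as jnp


def bytes_in_use(device):
    """Return the allocator's bytes in use for `device`, or None if unavailable."""
    stats = device.memory_stats()  # may be None on backends without stats
    if not stats:
        return None
    return stats.get("bytes_in_use")


@jax.jit
def fn1(x):
    # Hypothetical step function.
    return x + 1.0


device = jax.local_devices()[0]
state = jnp.zeros(1024)

baseline = None
grew = False
for _ in range(3):
    state = fn1(state)
    state.block_until_ready()  # make sure the step has actually finished
    used = bytes_in_use(device)
    if baseline is None:
        baseline = used  # first iteration sets the reference point
    elif used is not None and used > baseline:
        grew = True  # memory grew across iterations
```

This only detects net growth between iterations; it does not say which buffer failed to be donated.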

kxs-dhenshall commented 1 week ago

For future context, I worked around the issue by following the suggestion and asserting arr.is_deleted() after every call to the jitted function whenever a do-extra-validation flag was enabled.

This approach took a while, but in the end it gave me confidence that the buffers are being properly donated. If a future change causes arrays to always be flagged as deleted even when the buffer is not donated, these checks will no longer work.

A word of warning to anyone going down this path: buffer donation when pytrees are involved is a lot less obvious, and it took a while to get all my buffers properly donated. The warning message does not always appear when pytrees are involved.
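For anyone wanting to do the same, here is a minimal sketch of that kind of check for pytrees. The helper names `undonated_paths` and `call_checked` and the `validate` flag are hypothetical, not my actual wrapper. It reports which leaves of the donated argument are still alive after the call:

```python
import jax
import jax.numpy as jnp


def undonated_paths(tree):
    """Return the pytree paths of array leaves whose buffers are still alive."""
    leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves
        if isinstance(leaf, jax.Array) and not leaf.is_deleted()
    ]


def call_checked(fn, donated, *args, validate=True):
    """Call fn(donated, *args) and, if validating, require every leaf of
    `donated` to have been deleted (i.e. donated) by the call."""
    out = fn(donated, *args)
    if validate:
        alive = undonated_paths(donated)
        if alive:
            raise AssertionError(f"buffers not donated: {alive}")
    return out


# Simulate a partially donated pytree by deleting one leaf by hand.
state = {"a": jnp.ones(2), "b": jnp.ones(2)}
state["a"].delete()
remaining = undonated_paths(state)  # only the path for "b" is left
```

In real use it would wrap a donating function, e.g. `step = jax.jit(fn1, donate_argnums=0)` followed by `state = call_checked(step, state)`. Note that this relies on is_deleted() reflecting actual donation, which may not hold in future JAX versions.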

I am closing this issue since it is effectively resolved.

kxs-dhenshall commented 1 week ago

Closing per previous comment.