kxs-dhenshall opened 3 weeks ago
For 1) you can use https://jax.readthedocs.io/en/latest/transfer_guard.html
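For reference, a minimal sketch of how the `jax.transfer_guard` context is used (a real JAX API); the array and operation here are invented for illustration:

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4,))

# Inside this context, implicit transfers (e.g. converting a device array
# to a Python float or NumPy array) raise an error instead of silently
# copying data between host and device.
with jax.transfer_guard("disallow"):
    y = x * x  # pure device computation: allowed
    # float(y[0]) here would raise on an accelerator backend, since it
    # forces an implicit device-to-host transfer

# Outside the guard, transfers behave normally again.
print(float(y[0]))
```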
For 2) you can add an arr.is_deleted()
check to make sure donation was successful. But note that in the future, we will delete the input array regardless of whether donation was successful or not (which will help you?)
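A minimal sketch of that check, assuming the input buffer is donated via `donate_argnums`; the `step` function and array sizes are invented placeholders:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Donate the first argument so XLA may reuse its buffer for the output.
@partial(jax.jit, donate_argnums=0)
def step(state):
    return state + 1.0

state = jnp.zeros((1024,))
new_state = step(state)

# A successfully donated input is marked deleted after the call; touching
# it again would raise. Backends without donation support (e.g. some CPU
# configurations) leave the input intact instead.
donated = state.is_deleted()
```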
Thanks for the quick response, for (1) transfer_guard looks perfect.
For (2), I may do that, but I don't really want to keep around a bunch of checks. I have my own wrapper around the compiled function anyway, so I can add the is_deleted check there automatically and may just do that. In general, though, that isn't quite the tool I was hoping for. I am going to dig around to see what allocation stats are available; if there are good ones, I can use them to make sure that the memory allocated at the end of each loop iteration is no more than the memory allocated at the start.
For future context, I worked around the issue by using the suggestion: asserting arr.is_deleted()
after every call to the jitted function while a do-extra-validation flag was enabled.
This approach took a while but in the end gave me confidence that the buffers are being properly donated. If a future change causes arrays to always be flagged as deleted, even when the buffer is not donated, then these checks will no longer work.
Word of warning to anyone going down this path: the buffer donation logic is a lot less obvious when pytrees are involved, and it took a while to get all of my buffers properly donated. The warning message does not always fire when pytrees are involved.
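A sketch of what validating donation across a pytree might look like, checking every leaf rather than the container; the `step` function and state layout are invented placeholders:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def step(state):
    return {"x": state["x"] + 1.0, "y": state["y"] * 2.0}

state = {"x": jnp.zeros((8,)), "y": jnp.ones((8,))}
new_state = step(state)

# With pytrees it is easy for one leaf to silently miss donation, so
# check every leaf of the donated input, not just the top-level object.
leaves = jax.tree_util.tree_leaves(state)
undonated = [leaf for leaf in leaves if not leaf.is_deleted()]
```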
I am closing this issue since it is effectively resolved.
Closing per previous comment.
To work around long compile times and hidden re-compilations, a large piece of JIT-compiled logic can be broken into a sequence of smaller functions.
One implication of this is that the control flow logic has to run host-side, since using jax.lax control flow functions triggers compilation of the target condition/body functions, which effectively just compiles the original function again. For example, if you were to break a function f
into two functions f1
and f2,
your code may look something like this, with the loop running on the host:

Making this work efficiently requires a few extra considerations:
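The elided snippet might resemble the following minimal sketch; the bodies of `f1` and `f2` and the loop count are invented placeholders:

```python
import jax
import jax.numpy as jnp

# Two small jitted pieces of what was originally one large jitted function.
@jax.jit
def f1(state):
    return state + 1.0

@jax.jit
def f2(state):
    return state * 0.5

state = jnp.array([0.0])

# The loop runs on the host instead of via jax.lax.fori_loop, so each
# piece is compiled once and only once.
for _ in range(10):
    state = f1(state)
    state = f2(state)
```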
This works, but it is fragile, and it can be difficult to protect against regressions without burdensome performance testing. One thing that works great is that (3.) can be dealt with easily through AOT compilation, which throws an exception whenever a regression is introduced that mistakenly changes the type of a return value (e.g. an array in state is initialized to [1] with a data type of int32 when it should have been initialized as [1.] with a data type of float32).
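A minimal sketch of catching such a dtype regression with JAX's AOT API (`lower`/`compile`); the function and dtypes here are illustrative, not the author's actual code:

```python
import jax
import jax.numpy as jnp

def f2(state):
    return state * 2

# Compile ahead of time against the intended input shape and dtype.
good_state = jnp.array([1.0], dtype=jnp.float32)
compiled = jax.jit(f2).lower(good_state).compile()

out = compiled(good_state)  # matches the compiled signature

# A regression that produces int32 state where float32 was intended now
# fails loudly at the call site instead of silently recompiling.
bad_state = jnp.array([1], dtype=jnp.int32)
caught = False
try:
    compiled(bad_state)
except TypeError:
    caught = True
```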
For use cases like this, I think that it would be useful to have mechanisms that would protect against the following two cases:
One approach would be to introduce new contexts that prevent these from happening inside a critical section of code. For example, for (2.) it could look something like the following (please ignore the variable names, I did not put much thought into them):
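Since the proposed snippet is elided, here is a purely hypothetical user-space stand-in for the idea (this is not an existing or proposed JAX API): collect the arrays expected to be donated inside the critical section and verify on exit that they were all consumed. All names are illustrative:

```python
from contextlib import contextmanager

import jax.numpy as jnp

# Hypothetical guard: raise if any of the given arrays survives the block,
# i.e. was not consumed by donation.
@contextmanager
def assert_donated(*arrays):
    yield
    leaked = [a for a in arrays if not a.is_deleted()]
    if leaked:
        raise RuntimeError(f"{len(leaked)} buffer(s) were not donated")

state = jnp.zeros((4,))
with assert_donated(state):
    # Stands in for a jitted call that successfully donates `state`;
    # delete() marks the buffer consumed the same way donation would.
    state.delete()
```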
Is this something that has been thought about or considered? I am not sure whether other people have faced this issue, but I figured it was worth bringing up.