Closed: josevalim closed this 3 weeks ago
I looked into this, here are a few notes.
JAX has external callbacks, used in particular for printing. On CPU and GPU they use a custom call that invokes the Python callback function (on GPU the custom call first copies data off the GPU and then invokes the callback exactly as on CPU). On TPU they use send/recv operations (and eventually invoke the callback as well). I found one PR with more context: https://github.com/jax-ml/jax/pull/13759 (I also asked some questions there, but no answer so far).
We could implement the custom calls with `enif_send` to a named process. The main challenge is attaching enough information to the message to know which computation/callback it corresponds to. Ideally we want this to be compile-time information that we can encode as constant inputs to the custom call MLIR op. I think what could work is generating an id for the hook function via `fun |> term_to_binary |> md5`; the executing process would then subscribe to the named process using that id.
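The id derivation and subscription described above could look roughly like the following. This is a hedged sketch: `EXLA.HookRouter` is an assumed name for the globally registered process that the custom call would `enif_send` to, and `subscribe/1` is a hypothetical API.

```elixir
# Compile time: a deterministic id for the hook function, suitable for
# encoding as a constant input to the custom call MLIR op.
hook_fun = fn value -> IO.inspect(value, label: "hook") end
hook_id = :erlang.md5(:erlang.term_to_binary(hook_fun))

# Run time: the executing process subscribes under that id before
# starting the computation; the router (hypothetical) forwards messages
# sent by the custom call for that id.
:ok = EXLA.HookRouter.subscribe(hook_id)

receive do
  {:hook, ^hook_id, payload} -> hook_fun.(payload)
end
```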
On a separate note, we could possibly even add a callback API for getting a hook's return value back into `defn`. One idea would be to create a resource in the custom call that holds a mutex, a condition variable and a value field. The resource ref would be part of the message. After sending the message we would call `enif_cond_wait`. On the Elixir side, once the message is routed to the right process and the callback executes, we would call another NIF with the resource ref and the callback result. In that NIF we would put the result in the resource's value field and finally call `enif_cond_signal`. This could be useful, though not necessarily that useful. For JAX it makes sense, because they can invoke numpy/scipy functions not implemented in JAX, but in our case we would likely call NIFs instead, so we may as well skip the Elixir side altogether as in #1519.
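The Elixir half of that blocking round trip could be sketched as below. This is an assumption-laden sketch: the message shape `{:hook_call, resource_ref, payload}` and the `EXLA.Hooks.reply/2` NIF are hypothetical names, not existing APIs.

```elixir
# Runs in the process the router delivered the message to. The native
# thread is meanwhile blocked in enif_cond_wait on the resource's
# condition variable.
def handle_info({:hook_call, resource_ref, payload}, state) do
  result = state.callback.(payload)
  # Hypothetical NIF: stores `result` in the resource's value field and
  # calls enif_cond_signal to wake the waiting native thread.
  :ok = EXLA.Hooks.reply(resource_ref, result)
  {:noreply, state}
end
```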
All that said, implementing the custom calls is rather annoying, and it doesn't automatically translate to more platforms. For example, if we ended up adding the Metal plugin, we would need to implement another custom call (and I expect custom calls may not even be a thing there, so perhaps it needs yet another mechanism). The only reason to make this change now would be switching from the StreamExecutor GPU implementation to the PjRt plugins (which don't support infeed/outfeed), but that doesn't really make a difference for the end user, and we can likely maintain compatibility if we do it in the future (i.e. XLA_TARGET would download/register the necessary plugins, so existing setups would keep working). Given that many of these decisions happen internally and xla/jax/tensorflow are multiple separate efforts, things may shift in the future. For all these reasons we decided to wait and make changes once we really need them.
I think that the custom call could receive, as an MLIR attribute, something that encodes a ref pertaining to that specific computation. This should be enough for each XLA/IREE runtime backend to call `enif_send` to that process with the ref plus whatever data we want to send along.
edit: this is specifically for sending things out of the computation back to Elixir -- really useful for monitoring values in a long-running computation, for instance, or for debugging via `print_value`.
@polvalente to pass it as an MLIR attribute (or a constant input, since that's how we pass info to custom calls) it needs to be known at MLIR compile time, so it can't be any transient information like a ref/pid, especially since we can even cache the executable on disk.
@jonatanklosko I hadn't considered the possibility of model serialization. We can actually pass a PID (or any term, for that matter) as a runtime argument if we use `Nx.from_binary(:erlang.term_to_binary(pid), :u8)`. The custom call can then react to that argument by using `binary_to_term` to obtain the target PID.
This would add a new input to the function, but it's an alternative to having a fixed value.
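The encode/decode round trip could look like this. A minimal sketch: the custom call would do the decoding step natively (e.g. via `enif_binary_to_term`), so the Elixir-side decode below only illustrates what the native side would see.

```elixir
# Encode the target pid as a regular u8 input tensor. This becomes an
# extra runtime argument to the compiled function.
pid_tensor = Nx.from_binary(:erlang.term_to_binary(self()), :u8)

# What the custom call would do with the raw buffer it receives
# (shown in Elixir here; natively it would use enif_binary_to_term):
decoded_pid = :erlang.binary_to_term(Nx.to_binary(pid_tensor))
^decoded_pid = self()
```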
@polvalente I thought about using inputs, but it seems to me that it goes too far. The custom call alone already makes the MLIR Elixir-specific and not necessarily portable, and having a dedicated input is a step further. We need to understand whether it would make interoperability better or worse, particularly with regard to IREE and the Apple Metal plugin. Note that JAX emits custom MLIR code for these operations.