openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[JAX] Do not wait for the array deletion result #14861

Closed copybara-service[bot] closed 1 month ago

copybara-service[bot] commented 1 month ago


JAX Array.delete() behaved inconsistently depending on what the underlying IFRT Array::Delete() implementation does. JAX currently waits for the future returned by IFRT Array::Delete() and surfaces its result. PjRt-IFRT, however, always returns OK without blocking, which makes the wait moot. On another IFRT runtime whose Array::Delete() can return an error, JAX behaves differently. Waiting can also be very costly there if the error becomes available only after the physical buffer has been deleted, or only after an RPC request-response round trip has completed.
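The following is a minimal Python sketch of the old wait-and-surface behavior and why it diverges across runtimes. It uses `concurrent.futures.Future` as a stand-in for the IFRT future. `PjRtLikeRuntime`, `RemoteLikeRuntime`, and `old_array_delete` are hypothetical names for illustration only, not jaxlib or IFRT APIs. The simulated RPC latency is also an assumption.

```python
import time
from concurrent.futures import Future, ThreadPoolExecutor


class PjRtLikeRuntime:
    """Hypothetical stand-in: Delete() always returns an already-OK future."""

    def delete(self, handle):
        fut = Future()
        fut.set_result(None)  # OK is available immediately; waiting is free
        return fut


class RemoteLikeRuntime:
    """Hypothetical stand-in: Delete() resolves only after an RPC round trip."""

    def __init__(self, rpc_latency_s=0.2):
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._latency = rpc_latency_s

    def delete(self, handle):
        def _rpc():
            time.sleep(self._latency)  # simulated request-response round trip
            raise RuntimeError(f"delete failed for {handle}")

        return self._pool.submit(_rpc)


def old_array_delete(runtime, handle):
    """Hypothetical model of the old behavior: block on the deletion future
    and re-raise any error it carries."""
    runtime.delete(handle).result()
```

With `PjRtLikeRuntime`, `old_array_delete` returns immediately and never raises. With `RemoteLikeRuntime`, the same call blocks for the whole simulated round trip and then raises. This is the runtime-dependent behavior described above.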

This change resolves the inconsistency by making JAX Array.delete() no longer wait for the result of Array::Delete(). This has three side effects:

  1. JAX Array.delete() is idempotent on every runtime.
  2. JAX Array.delete() is always non-blocking.
  3. Deletion errors are no longer surfaced to the user.
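The three side effects can be illustrated with a small hypothetical model. This is a sketch under stated assumptions, not the jaxlib implementation. `Array` and `FailingRuntime` are illustrative names, and `concurrent.futures.Future` again stands in for the IFRT future. `delete()` issues the runtime deletion once, drops the returned future without waiting on it, and treats later calls as no-ops.

```python
from concurrent.futures import Future


class FailingRuntime:
    """Hypothetical runtime whose Delete() returns a future that may later fail."""

    def __init__(self):
        self.calls = 0
        self.futures = []

    def delete(self, handle):
        self.calls += 1
        fut = Future()  # left pending; a real runtime would complete it later
        self.futures.append(fut)
        return fut


class Array:
    """Hypothetical model of a JAX-level array wrapping one runtime buffer."""

    def __init__(self, runtime, handle):
        self._runtime = runtime
        self._handle = handle

    def delete(self):
        # Idempotent: once the handle is gone, further calls do nothing.
        if self._handle is None:
            return
        # Non-blocking: the returned future is intentionally never waited on,
        # so any error it later carries is not surfaced to the caller.
        self._runtime.delete(self._handle)
        self._handle = None

    def is_deleted(self):
        return self._handle is None
```

For example, calling `delete()` twice on the same `Array` invokes the runtime only once. A failure that the runtime reports later on the dropped future goes unobserved by the caller.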

This is technically a change to the deletion API semantics. However, the jaxlib-level implementation of deletion did not exactly match the high-level API semantics at the JAX level, and users do not rely on those semantics either. We therefore expect no regressions in user workloads, while the change makes the deletion API more consistent and improves its performance on some runtimes.