ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Core] JAX DeviceArray zero-copy #48960

Open Joshuaalbert opened 5 days ago

Joshuaalbert commented 5 days ago

Description

I could not find in the documentation whether this already exists, but the idea is simple: zero-copy handling of JAX arrays that already reside on the host, with an automatic transfer from the accelerator if the array is not on the host yet. The array would then be treated like a NumPy array in the object store and materialised back into a JAX host array on the other side.
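For context, a minimal sketch of what this round trip looks like today with a manual NumPy hop. The zero-copy part on the reader side is Ray's existing NumPy path; the JAX-side conversions are illustrative assumptions, not an existing API:

```python
import numpy as np
import ray
import jax
import jax.numpy as jnp

ray.init()

# Host-resident JAX array; device_get pulls it off the accelerator if needed.
x = jnp.arange(1_000_000, dtype=jnp.float32)
x_host = np.asarray(jax.device_get(x))

# Ray already stores NumPy arrays in shared memory and hands readers on the
# same node a zero-copy, read-only view.
ref = ray.put(x_host)

@ray.remote
def consume(arr: np.ndarray) -> float:
    # On the reader side the array arrives as NumPy; wrapping it in a
    # jax.Array currently copies onto the target device.
    y = jnp.asarray(arr)
    return float(y.sum())

print(ray.get(consume.remote(ref)))
```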

Use case

No response

yuriy-yarosh commented 4 days ago

Ray primarily uses Arrow dataframes, which are nullable and get cast to float64 when converted to NumPy, which results in precision loss. JAX does not support nullability yet: https://github.com/jax-ml/jax/issues/16289
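A small illustration of the cast being described, assuming the usual Arrow-to-pandas/NumPy conversion path; once nulls force a float64 representation, large int64 values are no longer exact:

```python
import pyarrow as pa

# A nullable int64 column: the null forces a float64 result on conversion.
col = pa.array([2**60 + 1, None], type=pa.int64())
print(col.to_pandas().dtype)    # float64
print(int(col.to_pandas()[0]))  # no longer exactly 2**60 + 1
```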

I'd suggest using Ray remotes only to initialize JAX Distributed and omitting pyarrow altogether.

You can perform basic service discovery and job planning in Ray, but multi-TPU / multi-GPU setups should work better with NCCL and JAX Distributed, due to the absence of dataframe transforms and the possibility of actual zero-copy (needs to be confirmed).
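A minimal sketch of that split, assuming Ray actors are used purely for placement and discovery while `jax.distributed.initialize` sets up the collective runtime; the actor and port choices are illustrative:

```python
import ray
import jax

ray.init()

@ray.remote(num_gpus=1)
class JaxWorker:
    def node_ip(self) -> str:
        # Ray only provides discovery: where each worker landed.
        return ray.util.get_node_ip_address()

    def setup(self, coordinator_address: str, num_processes: int, process_id: int) -> int:
        # JAX (NCCL under the hood on GPU) handles the collectives;
        # no dataframe transforms are involved in this path.
        jax.distributed.initialize(
            coordinator_address=coordinator_address,
            num_processes=num_processes,
            process_id=process_id,
        )
        return jax.process_index()

workers = [JaxWorker.remote() for _ in range(2)]
coordinator = f"{ray.get(workers[0].node_ip.remote())}:1234"
print(ray.get([
    w.setup.remote(coordinator, num_processes=2, process_id=i)
    for i, w in enumerate(workers)
]))
```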

There are numerous limitations in JAX Distributed workflows, and jax.Array has its own downsides in terms of scheduling and fault tolerance (there's not much, or enough, of it) compared to Ray. Some form of manual pyarrow dataframe transforms may be sufficient as well...

yuriy-yarosh commented 4 days ago

So, as far as Ray<->JAX interop support goes, it probably makes sense to add a tensor conversion to the convert_to_tensor funcs over here https://github.com/ray-project/ray/blob/master/rllib/utils/framework.py#L21 and here https://github.com/ray-project/ray/blob/master/rllib/env/env_runner.py#L162
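A hypothetical shape for such a converter, mirroring the existing convert-to-tensor style helpers the links point at; the function name and placement are assumptions, not existing RLlib API:

```python
from typing import Any

import numpy as np
import jax
import jax.numpy as jnp

def convert_to_jax_tensor(x: Any, device=None):
    """Recursively convert NumPy leaves of a nested structure into jax.Array."""
    def _convert(leaf):
        if isinstance(leaf, (np.ndarray, np.number, float, int)):
            arr = jnp.asarray(leaf)
            return jax.device_put(arr, device) if device is not None else arr
        return leaf
    # tree_map walks the dict/list/tuple nesting that RLlib batches use,
    # leaving non-numeric leaves untouched.
    return jax.tree_util.tree_map(_convert, x)
```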

Created a new issue #48987; hopefully folks around will help me clarify/confirm my assumptions.