Joshuaalbert opened 5 days ago
Ray primarily uses Arrow dataframes, which are nullable; nullable columns are cast to float64 when converted to numpy, which results in precision loss. Jax does not support nullability yet: https://github.com/jax-ml/jax/issues/16289
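To illustrate the cast (a minimal pyarrow sketch, not Ray-specific): an int64 column containing a null is converted to float64 with NaN, and integers above 2^53 lose precision along the way.

```python
import pyarrow as pa

# An int64 column containing a null: numpy has no nullable integers,
# so pyarrow falls back to float64 with NaN for the null.
a = pa.array([2**53 + 1, None], type=pa.int64())
np_a = a.to_numpy(zero_copy_only=False)

print(np_a.dtype)    # float64 -- the null forced the cast
print(int(np_a[0]))  # 9007199254740992, not 9007199254740993: precision lost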
I'd suggest using Ray remotes only to initialize Jax Distributed and omitting pyarrow altogether; a rough sketch follows.
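Something along these lines (a minimal sketch; the coordinator address, world size, and `num_gpus=1` placement are placeholder assumptions):

```python
import ray
import jax

@ray.remote(num_gpus=1)
class JaxWorker:
    def init_distributed(self, coordinator: str, num_processes: int, process_id: int):
        # Ray handles placement/lifecycle only; collectives go through
        # Jax Distributed (NCCL on GPU), not the object store.
        jax.distributed.initialize(
            coordinator_address=coordinator,
            num_processes=num_processes,
            process_id=process_id,
        )
        return jax.process_index()

workers = [JaxWorker.remote() for _ in range(2)]
coordinator = "10.0.0.1:1234"  # host:port of process 0 (placeholder)
# The calls run concurrently, so the coordinator handshake can complete.
print(ray.get([w.init_distributed.remote(coordinator, 2, i)
               for i, w in enumerate(workers)]))
```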
You can perform basic service discovery and job planning in Ray, but multi-TPU / multi-GPU setups should work better with NCCL and Jax Distributed, due to the absence of dataframe transforms and the possibility of actual zero-copy (needs to be confirmed).
There are numerous limitations in Jax Distributed workflows, and jax.Array has its own downsides in terms of scheduling and fault tolerance (there isn't much of it, certainly not enough) compared to Ray. Some form of manual pyarrow dataframe transforms may be sufficient as well...
So, as far as ray<->jax interop support goes... it probably makes sense to add Jax tensor conversion to the convert_to_tensor funcs over here https://github.com/ray-project/ray/blob/master/rllib/utils/framework.py#L21 and here https://github.com/ray-project/ray/blob/master/rllib/env/env_runner.py#L162
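Roughly, the jax branch could look like this (a sketch only; the function name and signature here are assumed from the linked files, not copied from rllib):

```python
import numpy as np

def convert_to_tensor(data, framework: str):
    # Hypothetical jax branch for rllib's conversion helpers.
    if framework == "jax":
        import jax.numpy as jnp
        # jnp.asarray stages the data onto the default Jax device;
        # for GPU/TPU backends this is a host-to-device copy.
        return jnp.asarray(np.asarray(data))
    raise NotImplementedError(framework)
```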
Created a new issue, #48987; hopefully folks around here will help me clarify/confirm my assumptions.
Description
I could not find in the documentation whether this already exists, but the idea is simple: zero-copy JAX interop for host-residing device arrays, with automatic transfer from the accelerator if the array is not on the host already. The array would then be treated like a numpy array in the object store and materialised into a JAX array on the other side.
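A minimal sketch of what that flow could look like today, done by hand (whether `device_get` is itself zero-copy for arrays already committed to host memory is exactly the part that needs confirming):

```python
import jax
import jax.numpy as jnp
import ray

ray.init()

def put_jax_array(x: jax.Array) -> ray.ObjectRef:
    # Pull to host if the array lives on an accelerator; device_get
    # returns a numpy array (a no-op transfer for host-committed data
    # is an assumption here, not a guarantee).
    return ray.put(jax.device_get(x))

def get_jax_array(ref: ray.ObjectRef) -> jax.Array:
    # ray.get of a numpy array is zero-copy out of the object store
    # (a read-only view over shared memory); jnp.asarray then stages
    # it onto the default JAX device.
    return jnp.asarray(ray.get(ref))

ref = put_jax_array(jnp.arange(8, dtype=jnp.float32))
print(get_jax_array(ref))
```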
Use case
No response