ray-project / ray


[Ray Trainer] `JaxTrainer`: Improve documentation, and explicitly mention some caveats. #25234

Open · JiahaoYao opened this issue 2 years ago

JiahaoYao commented 2 years ago

Description

This is assigned to myself.

Current approach for the JAX trainer using multi-GPU training (sketched below):
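Very roughly, the single-process multi-GPU version with `pmap` looks something like this. This is a minimal, hedged sketch: the loss function, parameter shapes, and learning rate are placeholders for illustration, not the actual trainer internals.

```python
# Hedged sketch: single-process multi-GPU training via pmap.
# Model and shapes are placeholders, not JaxTrainer code.
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    return jnp.mean((x @ params) ** 2)

@jax.pmap
def train_step(params, x):
    # One SGD step, replicated across every local device.
    grads = jax.grad(loss_fn)(params, x)
    return params - 0.01 * grads

n = jax.local_device_count()
params = jnp.stack([jnp.ones((4,))] * n)   # replicate params per device
batch = jnp.ones((n, 8, 4))                # shard the batch: leading axis = n
params = train_step(params, batch)
```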

TODO:

Use case

No response

cameronrutherford commented 2 years ago

To clarify and add to the conversation here: I think having one Ray process per GPU would work, provided you avoid pmap/pjit.
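Concretely, a minimal sketch of the one-process-per-GPU idea (`JaxWorker` and the trivial `train_step` are made up for illustration, and it assumes a cluster with at least two GPUs):

```python
# Sketch: one Ray actor per GPU. Each actor sees exactly one device,
# so plain jit works and pmap/pjit is never needed.
import ray
import jax
import jax.numpy as jnp

@ray.remote(num_gpus=1)
class JaxWorker:
    def train_step(self, x):
        # Ordinary single-device, jit-compiled computation.
        f = jax.jit(lambda v: jnp.sum(v ** 2))
        return float(f(jnp.asarray(x)))

ray.init()
workers = [JaxWorker.remote() for _ in range(2)]  # assumes 2 GPUs
print(ray.get([w.train_step.remote([1.0, 2.0]) for w in workers]))
```

This works because Ray sets `CUDA_VISIBLE_DEVICES` for each actor scheduled with `num_gpus=1`, so JAX inside the actor only ever discovers a single device.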

Having one Ray process per node might also work, but I suspect you would run into issues with JAX communicating across nodes.

If you wanted a single Ray workload spanning more than one node, each with more than one GPU, what would the recommendation be? I imagine this could be useful for really large models.
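One direction I could imagine, as a sketch only: have each Ray actor call JAX's multi-process setup (`jax.distributed.initialize`) against a shared coordinator, after which JAX collectives can span nodes. The coordinator endpoint and actor wiring below are placeholders; Ray does not set any of this up for you.

```python
# Hedged sketch: cross-node JAX from Ray actors via jax.distributed.
import ray
import jax

@ray.remote(num_gpus=1)
class DistributedJaxWorker:
    def setup(self, coordinator_address: str, num_processes: int, process_id: int):
        # After this call, jax.devices() reports devices on *all*
        # participating processes, so pmap/pjit can span nodes.
        jax.distributed.initialize(
            coordinator_address=coordinator_address,
            num_processes=num_processes,
            process_id=process_id,
        )
        return jax.process_index(), jax.device_count()

ray.init(address="auto")
n = 4  # e.g. 4 processes, one per GPU, spread across nodes
workers = [DistributedJaxWorker.remote() for _ in range(n)]
# "10.0.0.1:1234" is a placeholder for the coordinator endpoint.
ray.get([w.setup.remote("10.0.0.1:1234", n, i) for i, w in enumerate(workers)])
```

As far as I understand, `jax.distributed.initialize` blocks until all `num_processes` processes have connected, so launching the `setup` calls in parallel (as above) matters.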

AshishKumar4 commented 1 month ago

Hey! Any updates on this? I'd really love JAX support in Ray Train.