[Open] JiahaoYao opened this issue 2 years ago
To clarify and add to the conversation here: I think having one Ray process for each GPU would work, provided you avoid pmap/pjit.
Having one Ray process per node might work, but I suspect you would have issues with JAX communicating across nodes.
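The one-process-per-GPU pattern described above can be sketched without Ray itself: each worker process restricts its own `CUDA_VISIBLE_DEVICES` before any JAX import, so every process sees exactly one device and never needs a cross-process pmap/pjit. This is an illustrative stdlib-only sketch, not Ray's actual mechanism; the worker function and launcher names are hypothetical, and the JAX import is commented out so the sketch also runs on CPU-only machines.

```python
import multiprocessing as mp
import os


def gpu_worker(rank: int) -> str:
    # Pin this process to a single GPU *before* any JAX import: JAX reads
    # CUDA_VISIBLE_DEVICES at initialization, so this worker would see
    # exactly one device (its "local device 0").
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    # import jax  # would now enumerate only the one visible GPU
    return f"worker {rank} sees CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}"


def launch(num_gpus: int):
    # One process per GPU, mirroring one Ray actor per GPU.
    with mp.Pool(processes=num_gpus) as pool:
        return pool.map(gpu_worker, range(num_gpus))


if __name__ == "__main__":
    for line in launch(2):
        print(line)
```

With Ray, the same effect comes from giving each actor `num_gpus=1`: Ray then sets `CUDA_VISIBLE_DEVICES` for the actor process automatically, so each actor's JAX runtime is single-GPU.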
If you wanted one Ray worker spanning more than one node, each with multiple GPUs, what would the recommendation be? I imagine this could be useful for really large models.
Hey! Any updates on this? I'd really love JAX support in Ray Train.
Description
This is assigned to myself.
Current approach for the JAX trainer using multi-GPU training:
TODO:
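As a point of reference for what a multi-GPU JAX trainer entails: data-parallel training boils down to each worker computing gradients on its own shard and then averaging them across workers (what `jax.lax.pmean` does inside a `pmap`, and what an explicit allreduce would do between one-GPU-per-process Ray workers). A dependency-free sketch of that averaging step, with hypothetical function names:

```python
from typing import List


def allreduce_mean(per_worker_grads: List[List[float]]) -> List[float]:
    # Element-wise average of gradients across workers -- the collective at
    # the heart of data-parallel training (lax.pmean inside pmap, or an
    # explicit allreduce between single-GPU Ray workers).
    num_workers = len(per_worker_grads)
    return [sum(components) / num_workers for components in zip(*per_worker_grads)]


def sgd_step(params: List[float], grads: List[float], lr: float = 0.1) -> List[float]:
    # Every worker applies the identical averaged gradient, so all model
    # replicas stay bit-for-bit in sync after the step.
    return [p - lr * g for p, g in zip(params, grads)]
```

For example, two workers reporting gradients `[0.2, 0.4]` and `[0.0, 0.2]` average to `[0.1, 0.3]`, and each worker applies that same update to its local copy of the parameters.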
Use case
No response