google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Set JAX_PLATFORMS to "tpu, cpu" for ray worker #145

Closed richardsliu closed 2 months ago

richardsliu commented 2 months ago

Issue: The Ray worker could fail to actually acquire the TPU devices if another process is already using the TPU. This would cause JAX to silently fall back to CPU in the Ray worker process.
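For context, when JAX is given a platform list such as "tpu,cpu", it falls back to the next platform if the TPU cannot be initialized, without raising an error. A minimal way to see which backend actually won (a hedged illustration, not code from this repo):

```python
import jax

# If another process already holds the TPU, this prints "cpu" rather
# than raising an error -- the silent fallback described above.
print(jax.default_backend())
```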

Fix: When running with Ray, all processes should default to JAX_PLATFORMS=cpu (this can be set from the YAML configuration, which is not included in this PR). When a Ray worker is created, override its runtime environment to tpu,cpu. This ensures that the Ray worker can acquire access to the TPU devices; a sketch of the override is shown below.
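A minimal sketch of what this override could look like with Ray's runtime_env mechanism. The actor name TPUWorker and the "TPU" resource label are illustrative assumptions, not taken from this PR; the driver process is assumed to already have JAX_PLATFORMS=cpu set (e.g. via the YAML configuration mentioned above):

```python
import ray

ray.init()

@ray.remote(resources={"TPU": 1})  # assumes a custom "TPU" resource is registered
class TPUWorker:
    def devices(self):
        # JAX is imported inside the worker, after the runtime_env has
        # set JAX_PLATFORMS, so the platform list takes effect here.
        import jax
        return [str(d) for d in jax.devices()]

# Override the worker's environment so it prefers the TPU and only
# falls back to CPU if the TPU backend is unavailable.
worker = TPUWorker.options(
    runtime_env={"env_vars": {"JAX_PLATFORMS": "tpu,cpu"}}
).remote()

print(ray.get(worker.devices.remote()))
```

The key point is that the CPU-only default applies to every process, and only the Ray worker that actually needs the accelerator gets the tpu,cpu override, so no other process can grab the TPU first.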