Issue: The Ray worker could fail to acquire the TPU devices if another process was already using the TPU. JAX would then fall back to CPU in the Ray worker process.
Fix: When running with Ray, every process should have JAX_PLATFORMS set to cpu by default (this can be set from the YAML configuration, which is not included in this PR). When a Ray worker is created, its runtime environment overrides JAX_PLATFORMS to tpu,cpu. Because the other processes stay on CPU and never claim the TPU, the Ray worker can acquire the TPU device.
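The override described above could look roughly like the sketch below. It is an illustration, not the code in this PR: the helper `tpu_worker_runtime_env`, the launcher `launch_tpu_worker`, and the actor name `TpuWorker` are hypothetical, and the only Ray features assumed are the `runtime_env` option with its `env_vars` field.

```python
import os
from typing import Optional


def tpu_worker_runtime_env(base: Optional[dict] = None) -> dict:
    """Return a copy of `base` with JAX_PLATFORMS forced to "tpu,cpu".

    Hypothetical helper: other runtime_env keys and env vars are preserved,
    and the caller's dict is not mutated.
    """
    runtime_env = dict(base or {})
    env_vars = dict(runtime_env.get("env_vars", {}))
    env_vars["JAX_PLATFORMS"] = "tpu,cpu"  # try TPU first, then CPU
    runtime_env["env_vars"] = env_vars
    return runtime_env


def launch_tpu_worker():
    """Hypothetical launcher; needs Ray installed and a TPU host to be useful."""
    import ray  # imported lazily so the helper above works without Ray

    @ray.remote
    class TpuWorker:  # hypothetical actor name
        def platforms(self) -> str:
            # Reports the value the worker process actually sees.
            return os.environ.get("JAX_PLATFORMS", "")

    # Only this actor's process gets the tpu,cpu override; every other
    # process keeps the cpu default set elsewhere (e.g. in the YAML config).
    return TpuWorker.options(runtime_env=tpu_worker_runtime_env()).remote()
```

Keeping the override in a small pure function makes it easy to check without a cluster, and it leaves any other runtime environment settings the worker already has untouched.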