miladm opened this issue 3 months ago
AFAIK this is a hard requirement for PJRT; we added the error message to make it more discoverable.
There might be a way to implement is_tpu_available without initializing the runtime, though, like just checking the device dir. @will-cromar I think we have some examples somewhere.
This function safely checks how many TPUs are actually attached via PCI: https://github.com/pytorch/xla/blob/dc3b265ae84f2fc2dfeaa20d78e8634b85a2267d/torch_xla/_internal/tpu.py#L101-L114
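For reference, here is a rough sketch of that kind of PCI-level check (illustrative only; the helper name and the vendor ID constant are assumptions on my part, and the linked tpu.py function is the authoritative version):

```python
import glob

# Assumed PCI vendor ID for Google TPU devices; the real check in
# torch_xla/_internal/tpu.py is the source of truth.
_GOOGLE_PCI_VENDOR_ID = '0x1ae0'


def _num_attached_tpu_chips() -> int:
  """Counts TPU chips by scanning the PCI device tree, no runtime init."""
  count = 0
  for vendor_path in glob.glob('/sys/bus/pci/devices/*/vendor'):
    try:
      with open(vendor_path) as f:
        if f.read().strip() == _GOOGLE_PCI_VENDOR_ID:
          count += 1
    except OSError:
      # Device entries can disappear between glob and open; skip them.
      continue
  return count
```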
Why is the code path different for TPU vs CUDA at all? torchrun with no spawn should work for both cases.
edit: That is to say, is_tpu_available would be trivial to implement since we already effectively do this:
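Something along these lines would probably be enough, reusing a device-dir/PCI check like the one sketched above (the name is illustrative, not an existing torch_xla API):

```python
# Hypothetical high-level helper, built on the _num_attached_tpu_chips sketch
# above: reports TPU availability without initializing the PJRT runtime.
def is_tpu_available() -> bool:
  return _num_attached_tpu_chips() > 0
```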
We should have a clear idea of what the use case is before adding it to our high-level API. I don't think using it to select between launch modes makes sense here.
🐛 Usability Bug
Consider the following scenario:
Helper Code:
Calling APIs like xm.xla_real_devices() before xmp.spawn causes failures except when nprocs=1. This execution pattern looks confusing for the user. The execution pattern / error generation should be consistent no matter what nprocs is used.

Error:
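To make the pattern concrete, a minimal sketch of the scenario described above (illustrative only, not the original helper code or error output):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  print(index, xm.xla_device())


if __name__ == '__main__':
  # Touching the runtime in the parent process before spawning...
  print(xm.xla_real_devices())
  # ...then spawning across all devices fails under PJRT; nprocs=1 works.
  xmp.spawn(_mp_fn, args=())
```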