pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Device init before `xmp.spawn()` #7765

Open miladm opened 3 months ago

miladm commented 3 months ago

🐛 Usability Bug

Consider the following scenario:

Helper Code:

import torch_xla.core.xla_model as xm

def is_tpu_available():
    # Querying the real devices initializes the PJRT runtime as a side effect.
    devices = xm.xla_real_devices()
    if len(devices) > 0:
        return 'TPU' in devices[0]
    return False

Calling APIs like xm.xla_real_devices() before xmp.spawn() causes failures except when nprocs=1. This execution pattern is confusing for users; the behavior and error generation should be consistent regardless of which nprocs value is used.

# Working Case
if __name__ == '__main__':
    if is_tpu_available(): # TPU device check
        xmp.spawn(_mp_fn, nprocs=1)
    else: #e.g. CUDA device
        train_gpt()
# Failing Case
if __name__ == '__main__':
    if is_tpu_available(): # TPU device check
        xmp.spawn(_mp_fn)
    else: #e.g. CUDA device
        train_gpt()

Error:

RuntimeError: Runtime is already initialized. Do not use the XLA device before calling xmp.spawn.
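
For reference, the ordering the error message asks for looks roughly like the following. This is a minimal sketch; the _mp_fn body here is illustrative, not taken from the reporter's script.

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # Safe: each spawned worker initializes the PJRT runtime itself.
    device = xm.xla_device()
    print(index, device)

if __name__ == '__main__':
    # No XLA device APIs are touched in the parent process before spawn.
    xmp.spawn(_mp_fn)
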
JackCaoG commented 3 months ago

AFAIK this is a hard requirement for PJRT; we added the error message to make it more discoverable.

JackCaoG commented 3 months ago

There might be a way to implement is_tpu_available without initializing the runtime, though, like just checking the device dir. @will-cromar I think we have some examples somewhere

will-cromar commented 3 months ago

This function safely checks how many TPUs are actually attached via PCI: https://github.com/pytorch/xla/blob/dc3b265ae84f2fc2dfeaa20d78e8634b85a2267d/torch_xla/_internal/tpu.py#L101-L114
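
In the same spirit, a PCI-based check that reads sysfs instead of touching the XLA runtime could look like the sketch below. The vendor and device IDs and the helper name are illustrative assumptions, not the constants from the linked tpu.py.

import glob
import pathlib

_GOOGLE_PCI_VENDOR_ID = '0x1ae0'            # assumed Google PCI vendor ID
_TPU_PCI_DEVICE_IDS = {'0x0027', '0x005e'}  # assumed example TPU device IDs

def num_attached_tpu_devices() -> int:
    # Count PCI devices whose vendor/device IDs match the assumed TPU IDs;
    # this never initializes the PJRT runtime.
    count = 0
    for dev in glob.glob('/sys/bus/pci/devices/*'):
        vendor = pathlib.Path(dev, 'vendor').read_text().strip()
        device = pathlib.Path(dev, 'device').read_text().strip()
        if vendor == _GOOGLE_PCI_VENDOR_ID and device in _TPU_PCI_DEVICE_IDS:
            count += 1
    return count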

Why is the code path different for TPU vs. CUDA at all? torchrun with no spawn should work for both cases.

edit: That is to say, is_tpu_available would be trivial to implement, since we already effectively do this:

https://github.com/pytorch/xla/blob/dc3b265ae84f2fc2dfeaa20d78e8634b85a2267d/torch_xla/runtime.py#L65-L67

We should have a clear idea of what the use case is before adding it to our high-level API. I don't think using it to select between launch modes makes sense here.
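
If a boolean helper is still wanted for scripts, a thin wrapper over the PCI count above would avoid initializing the runtime. A minimal sketch, assuming the torch_xla._internal.tpu module and its num_available_devices() helper (per the file linked above) are importable in the installed version:

from torch_xla._internal import tpu  # assumed import path, per the linked file

def is_tpu_available() -> bool:
    # Counts attached TPU PCI devices via sysfs; does not initialize PJRT.
    return tpu.num_available_devices() > 0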