pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

allowing `xm.get_ordinal()` and default device in `xm.xla_device()` in PJRT #3692

ronghanghu closed this issue 2 years ago

ronghanghu commented 2 years ago

🚀 Feature

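Context: The new PJRT runtime provides many advantages over the XRT runtime and is a really great new feature in PyTorch/XLA. On the other hand, the current PJRT MNIST and PJRT ImageNet examples require manually computing each thread's rank (from CLOUD_TPU_TASK_ID and the thread index) and manually passing that index to xm.xla_device(index), instead of relying on xm.get_ordinal() and xm.xla_device().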

These differences will likely break a lot of existing user code that directly uses xm.get_ordinal() and xm.xla_device() to get the rank and the XLA device (e.g. the FSDP implementation, the MoCo v3 implementation, and many other codebases), and they introduce bugs and overhead when transitioning to PJRT.

Proposed new feature:

For compatibility with existing code, it would be great to allow directly getting the rank via xm.get_ordinal() (which should return a value from 0 to xm.get_world_size() - 1) and directly getting the default XLA device via xm.xla_device() when using PJRT.

This could perhaps be implemented by storing the index as a thread-local variable and reading it in xm.get_ordinal() and xm.xla_device() when using PJRT.
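A minimal sketch of the thread-local idea, using only the Python standard library. The names `_thread_state`, `_set_thread_ordinal`, `get_ordinal_sketch`, and `run_threads` are hypothetical and are not part of torch_xla; the point is only that each thread sees its own stored value, so a zero-argument getter can still return a per-thread rank.

```python
import threading

# Hypothetical per-thread storage; NOT a torch_xla API.
_thread_state = threading.local()


def _set_thread_ordinal(ordinal: int) -> None:
    """Record the global rank for the calling thread (done once by the launcher)."""
    _thread_state.ordinal = ordinal


def get_ordinal_sketch(default: int = 0) -> int:
    """Return the calling thread's rank, mimicking a zero-argument get_ordinal()."""
    return getattr(_thread_state, "ordinal", default)


def _worker(ordinal: int, seen: dict) -> None:
    # The launcher stores the rank; "user code" then calls the zero-argument getter.
    _set_thread_ordinal(ordinal)
    seen[ordinal] = get_ordinal_sketch()


def run_threads(base_rank: int, threads_per_process: int = 2) -> dict:
    """Spawn one thread per device in this process and report what each thread saw."""
    seen = {}
    threads = [
        threading.Thread(target=_worker, args=(base_rank + i, seen))
        for i in range(threads_per_process)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen
```

For example, `run_threads(base_rank=2)` returns `{2: 2, 3: 3}`, while the main thread, which never stored a value, still gets the default from `get_ordinal_sketch()`.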

Also, a related bug fix for the PJRT MNIST (and ImageNet) example: when computing the rank on TPU v3, CLOUD_TPU_TASK_ID should be multiplied by 2, as mentioned in https://github.com/pytorch/xla/pull/3607#discussion_r915347367.
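A minimal sketch of the corrected rank arithmetic, assuming (as described in this thread) that each TPU v3 process drives 2 cores through 2 threads and that CLOUD_TPU_TASK_ID numbers the processes from 0. The helper names `pjrt_global_rank` and `rank_from_env` are hypothetical. Without the factor of 2, task 1 / thread 0 and task 0 / thread 1 would both end up with rank 1.

```python
import os


def pjrt_global_rank(task_id: int, thread_index: int, devices_per_process: int) -> int:
    """Global rank of a thread when each process drives `devices_per_process` cores."""
    if not 0 <= thread_index < devices_per_process:
        raise ValueError("thread_index must be in [0, devices_per_process)")
    # Process k owns the contiguous block of ranks [k * n, k * n + n).
    return task_id * devices_per_process + thread_index


def rank_from_env(thread_index: int, devices_per_process: int = 2) -> int:
    """Read the task id from the environment (default 0) and compute the global rank."""
    task_id = int(os.getenv("CLOUD_TPU_TASK_ID", "0"))
    return pjrt_global_rank(task_id, thread_index, devices_per_process)
```

On a v3-8 with 4 processes, this yields ranks 0 through 7 with no collisions.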


Below are also a few of my thoughts on PJRT.

Regarding random ops in PyTorch (and e.g. NumPy)

A major difference between PJRT and XRT on TPU v3 is that under PJRT each process needs to control the 2 TPU cores in a chip, instead of having 1 process per TPU core. While most of the differences can be bridged by spawning 2 threads per process on TPU v3, random seeds can only be set at the process level, so random PyTorch or NumPy ops cannot be matched exactly between PJRT and XRT. This has a few implications.

On enforcing deterministic behavior: this PJRT vs. XRT difference in random seeding can show up in several places when users try to enforce deterministic behavior (such as making PyTorch data loaders or other random ops always give the same output). It is related to the de facto MPMD assumption in most PyTorch codebases that there is one process per device (typically a GPU, as in the DDP case). From my perspective, although this is undesirable, it is still fine to have this discrepancy between PJRT and XRT: when using PJRT on TPU v3, users need to set random seeds at the process level and share the random number generator between the two threads. If one really needs deterministic behavior, then users need to allocate and specify a separate random number generator in each thread. (I'm happy to hear other people's opinions on this!)
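A minimal sketch of the per-thread generator approach, assuming the random ops in question accept an explicit `torch.Generator` (shown here on CPU). Seeding each generator with `base_seed + rank` is an illustrative choice, not a torch_xla convention, and the helper names are hypothetical. Ops that do not take a generator argument, such as dropout, still draw from the process-wide default generator and are not covered by this.

```python
import threading

import torch


def _draw(rank: int, base_seed: int, out: list) -> None:
    # Each thread owns a private generator, so its random stream depends only on
    # its own seed, not on how the two threads interleave or on the global seed.
    gen = torch.Generator()
    gen.manual_seed(base_seed + rank)
    out[rank] = torch.rand(4, generator=gen)


def per_thread_samples(base_seed: int = 1234, num_threads: int = 2) -> list:
    """Run one thread per device and collect each thread's random draw."""
    out = [None] * num_threads
    threads = [
        threading.Thread(target=_draw, args=(rank, base_seed, out))
        for rank in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out
```

Repeated calls return identical tensors per thread, while the two threads still get different streams. The same per-thread generator can also be handed to `torch.utils.data.DataLoader(..., generator=gen)` so that shuffling is reproducible per thread.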

On model initialization: this PJRT vs. XRT difference in random seeding is probably why the PJRT MNIST example needs to pass the model state dict as an argument to the function, which makes it more complicated than the XRT MNIST example, where every TPU core gets the same state dict because each process sets the same random seed. However, I think this extra hassle of passing in the state dict is fine for the PJRT MNIST example. In fact, in real use cases one would rarely want to set the same random seed across processes as the XRT MNIST example does, since that leads to undesired behaviors such as identical dropout masks, identical data augmentation, and identical sampled noise variables (e.g. in GAN training) across all TPU cores, all of which could hurt training.

The more common way to initialize a model in PyTorch is to set different random seeds across devices and then synchronize the model parameters so that every TPU device starts with the same initial parameters. PyTorch DDP implements this by broadcasting the model state dict from rank 0 to the other ranks during initialization. In previous PyTorch/XLA use cases, I implemented similar parameter and buffer broadcasting as follows:

```python
from itertools import chain

import torch
import torch_xla.core.xla_model as xm


def broadcast_xla_master_model_param(model):
    """
    Broadcast the model parameters and buffers from the master process to other processes
    """
    parameters_and_buffers = []
    is_master = xm.is_master_ordinal(local=False)
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is
        # equivalent to broadcasting parameters from master to other devices.
        scale = torch.tensor(1 if is_master else 0, dtype=p.data.dtype)
        scale = scale.to(p.data.device)
        p.data.mul_(scale)
        parameters_and_buffers.append(p.data)
    # In-place sum across devices; only the master's copies are non-zero.
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    # Synchronize all processes before continuing.
    xm.rendezvous("broadcast_xla_master_model_param")
```
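To make the zero-then-sum reasoning in the comments concrete, here is a plain-Python simulation with no TPU involved: each inner list stands in for one rank's copy of a parameter, and a column-wise sum stands in for the all-reduce. The function names are hypothetical and only illustrate the idea.

```python
from typing import List


def simulated_all_reduce_sum(per_rank_values: List[List[float]]) -> List[float]:
    """Element-wise sum across ranks, standing in for a SUM all-reduce."""
    return [sum(column) for column in zip(*per_rank_values)]


def broadcast_via_all_reduce(per_rank_values: List[List[float]], master: int = 0) -> List[float]:
    # Step 1: the master keeps its values (scale 1); every other rank zeroes its copy.
    masked = [
        [v * (1 if rank == master else 0) for v in values]
        for rank, values in enumerate(per_rank_values)
    ]
    # Step 2: summing the masked copies leaves exactly the master's values,
    # and every rank receives that sum as the all-reduce result.
    return simulated_all_reduce_sum(masked)
```

With differently initialized copies `[[0.5, -1.25], [3.0, 4.0], [7.0, 8.0]]`, the result is `[0.5, -1.25]`, i.e. rank 0's values. One property of the multiply-by-zero trick worth keeping in mind: it assumes the non-master copies are finite, since `inf * 0` is NaN in IEEE floating point.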

With the parameter synchronization above, one doesn't need to set the same random seed on every TPU core when initializing model parameters, and therefore doesn't need to build the entire model first and pass model.state_dict() to common.run_pjrt_multiprocess as in the PJRT MNIST example. In many use cases, it is impossible to build the entire model before spawning processes.

Regarding internal SPMD over the two cores on TPU v3

Another potential way (discussed in an offline chat with @JackCaoG) to bridge the random seed difference between PJRT and XRT on TPU v3 is to spawn only 4 processes (without 2 threads per process) and do internal SPMD over the two cores in each process. This way, only 4 devices are visible to users, so there is still one process per device.

However, I feel this internal SPMD is not trivial to implement without requiring users to add a lot of axis annotations. For example, suppose we do data parallelism between the two cores via implicit SPMD and users want to compute something like y = matmul(W, x): how do we know whether W or x (and which of its axes) is the "data" dimension that needs to be partitioned, without users telling us explicitly? Also, what if the data/batch dimension is not divisible by the number of cores (e.g. a batch size of 4 instead of 8 on a v3-8, which seems legitimate if users can only see 4 TPU devices)? It seems to me that internal, implicit SPMD would have a hard time covering all general cases.
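The divisibility concern can be made concrete with a little arithmetic, assuming a v3-8 exposed as 4 visible devices with 2 cores each and an even split of the batch dimension first over devices and then over cores. `per_core_batch` is a hypothetical helper for illustration only.

```python
def per_core_batch(global_batch: int, visible_devices: int, cores_per_device: int = 2) -> int:
    """Per-core batch if implicit SPMD splits each visible device's batch over its cores."""
    if global_batch % visible_devices:
        raise ValueError("global batch is not divisible by the number of visible devices")
    per_device = global_batch // visible_devices
    if per_device % cores_per_device:
        # Legal from the user's point of view (one sample per visible device),
        # but the hidden split across the chip's cores has no even partition.
        raise ValueError(
            f"per-device batch {per_device} cannot be split evenly over {cores_per_device} cores"
        )
    return per_device // cores_per_device
```

A global batch of 8 gives each core 1 example, while a global batch of 4 is valid for 4 visible devices but fails at the hidden per-core split.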

Also, it could be hard to debug collective ops when we have both internal SPMD over the 2 cores in a chip and user-specified collective ops between chips. So I would prefer either a) SPMD over the entire TPU mesh or b) entirely user-specified collective ops and sharding. A mixture of the two would make debugging and profiling quite complicated, so I feel it at least shouldn't be the default behavior on TPU v3.

cc: @JackCaoG

ronghanghu commented 2 years ago

xm.rendezvous seems to be another API that needs patching for PJRT (it crashed when I tried to port a few existing use cases to PJRT).

The implementation of xm.rendezvous in https://github.com/pytorch/xla/blob/2bdd718b4b7309b5868825e261ae05bef6be548f/torch_xla/core/xla_model.py#L1058 calls get_ordinal(), which is not implemented yet for PJRT.

However, it still seems to fail even when I manually run

```python
import os
import torch_xla
rank = int(os.getenv('CLOUD_TPU_TASK_ID', 0)) * 2 + index  # manually compute the rank on TPU v3 (`index` = thread index within this process)
torch_xla._XLAC._xla_rendezvous(rank, 'example_message_for_rendezvous', b'', [])
```

on TPU v3. So I wonder: is _xla_rendezvous expected to work together with threads in PJRT on TPU v3?

ronghanghu commented 2 years ago

In addition, I suspect that the current xm.all_gather (and perhaps also xm.reduce_scatter) could be problematic under PJRT.

When pinning layout in all-gather, _all_gather_using_all_reduce relies on get_ordinal() in https://github.com/pytorch/xla/blob/2bdd718b4b7309b5868825e261ae05bef6be548f/torch_xla/core/xla_model.py#L623, which needs to be patched. This seems to break several existing PyTorch/XLA codebases in our internal tests on PJRT. However, there were still large discrepancies between XRT and PJRT in several cases even after we patched get_ordinal() on our end to return the ranks on TPU v3. We're trying to pin down the issue and see whether it is related to other all-gather and reduce-scatter issues.
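To illustrate why this path depends on get_ordinal(), here is a plain-Python simulation of the pad-and-sum technique that the name _all_gather_using_all_reduce suggests: each rank zero-pads its shard into the slot given by its ordinal, and a SUM all-reduce stitches the non-overlapping slots together. The function below is hypothetical and only mirrors the idea; it does not reproduce torch_xla's code. If both threads of a v3 process report the same ordinal, their shards land in the same slot and get added together.

```python
from typing import List


def all_gather_via_all_reduce(shards: List[List[float]], ordinals: List[int]) -> List[float]:
    """Simulate all-gather by padding each shard into its slot and summing over ranks.

    shards[r] is participant r's local data and ordinals[r] is the rank it *reports*;
    a correct setup has ordinals == list(range(len(shards))).
    """
    world_size = len(shards)
    shard_len = len(shards[0])
    padded = []
    for shard, ordinal in zip(shards, ordinals):
        # Zero-pad so the shard occupies [ordinal * shard_len, (ordinal + 1) * shard_len).
        buf = [0.0] * (world_size * shard_len)
        buf[ordinal * shard_len:(ordinal + 1) * shard_len] = shard
        padded.append(buf)
    # The SUM all-reduce: slots don't overlap, so the sum concatenates the shards.
    return [sum(column) for column in zip(*padded)]
```

With correct ordinals, `all_gather_via_all_reduce([[1.0, 2.0], [3.0, 4.0]], [0, 1])` gives `[1.0, 2.0, 3.0, 4.0]`; with a thread-unaware ordinal `[0, 0]`, it silently gives `[4.0, 6.0, 0.0, 0.0]`.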

When not pinning layout in all-gather, no explicit thread index is passed into torch_xla._XLAC._xla_all_gather in https://github.com/pytorch/xla/blob/2bdd718b4b7309b5868825e261ae05bef6be548f/torch_xla/core/xla_model.py#L685-L686 (unlike xm.xla_device(index), which takes one to get the XLA device). Does it correctly handle the rank needed for all-gather in each thread under PJRT? (The same question applies to _xla_all_gather_out and _xla_reduce_scatter_out, which also potentially need rank information.)

will-cromar commented 2 years ago

Thanks @ronghanghu for the awesome detail on this issue!

You touched on a few issues here. For now, I'll focus on implementing xm.get_ordinal, xm.xla_device, etc. to have the right default behaviors with PJRT.

We haven't decided how we want to handle the complications with v3 long-term. Short-term, we are likely going to stick with 4 processes with 2 threads each, so I'll make sure that any changes I make here support that. Thanks for the feedback on the SPMD idea that @JackCaoG and I were kicking around.