pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

How are PJRT asynchronous executions throttled by torch_xla? #8380

Open mcuiaws opened 1 week ago

mcuiaws commented 1 week ago

🐛 Bug

Here at AWS we have a single PJRT device plugin for both PyTorch and JAX, and we recently made improvements to the plugin so that it works better with JAX. Specifically, PJRT_LoadedExecutable_Execute() is now fully asynchronous: we queue up an execution and return immediately, expecting the caller to wait on the returned_future. Before this change, execution was synchronous and was complete by the time PJRT_LoadedExecutable_Execute() returned.

As soon as we switched to the new implementation, we noticed that torch_xla queues up as many executions as it can, with no throttling in either PJRT or torch_xla, which makes it easy to exhaust device memory. It appears that there are no internal throttling mechanisms, only explicit ones that must be triggered by user code:

  1. when xm.wait_device_ops() is called, which calls down to WaitDeviceOps()
  2. when a tensor is read, which internally calls WaitDeviceOps()

However, WaitDeviceOps() is a heavy hammer because it pauses the world until the entire pipeline is drained. Ideally we do not want to rely on it for throttling, and we do not want the user to have to guess where to insert these calls to avoid running out of memory (a sketch of what that looks like in a training loop follows below). Some sensible internal throttling mechanism is needed.
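To illustrate the second point, here is a minimal sketch of the explicit, user-driven throttling that is available today in a training-style loop. xm.mark_step() and xm.wait_device_ops() are the real torch_xla calls; the matmul workload and the drain interval are made up for the example:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
weight = torch.randn(1024, 1024, device=device)

for step in range(100):
    x = torch.randn(1024, 1024, device=device)
    weight = weight @ x                 # some device computation to keep the pipeline busy
    xm.mark_step()                      # cut the lazy graph and dispatch its execution

    # With a fully asynchronous backend and no internal throttling, these
    # dispatched executions pile up on the device. The only explicit brakes
    # today are the two mechanisms above, e.g. periodically draining everything:
    if step % 10 == 0:                  # interval picked arbitrarily for illustration
        xm.wait_device_ops()            # blocks until ALL pending device work finishes
```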

The main issue here is that pjrt_computation_client.cc does not wait on the returned_future from PJRT; it simply throws it away.

However, according to torch's lazy_graph_executor, "only one asynchronous operation can execute at the same time, on a given device." This is enforced by a device lock, which is supposed to be held for the entire duration of the asynchronous execution. But in torch_xla's xla_graph_executor.cpp, the device locks acquired by torch are released as soon as ExecuteComputation() returns, and ExecuteComputation() does not actually wait for the computation to complete. Therefore, the throttling mechanism of torch's lazy_graph_executor is defeated here.
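To make the lock-lifetime mismatch concrete, here is a minimal conceptual sketch in Python; device_lock and dispatch_execution are stand-ins for the torch_xla/PJRT internals, not actual APIs:

```python
import threading
from concurrent.futures import Future

# What xla_graph_executor.cpp effectively does today: the device lock is
# released as soon as the execution has been *dispatched*, and the returned
# future is dropped (as pjrt_computation_client.cc does).
def execute_release_on_dispatch(device_lock: threading.Lock, dispatch_execution) -> None:
    with device_lock:                    # held only while queuing the work
        returned_future = dispatch_execution()
    del returned_future                  # future discarded; nothing throttles the device

# What the lazy_graph_executor contract implies: hold the device lock for the
# entire duration of the asynchronous execution, releasing it only when the
# returned future completes.
def execute_release_on_completion(device_lock: threading.Lock, dispatch_execution) -> None:
    device_lock.acquire()
    returned_future: Future = dispatch_execution()
    returned_future.add_done_callback(lambda _: device_lock.release())
```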

JackCaoG commented 1 week ago

We control it with XLA_TPU_MAX_INFLIGHT_COMPUTATIONS, but I guess that's a TPU-specific flag.

mcuiaws commented 1 week ago

We can do something similar for Neuron. How is XLA_TPU_MAX_INFLIGHT_COMPUTATIONS implemented? Do you block inside PJRT_LoadedExecutable_Execute() if the client queues up too many executions?

JackCaoG commented 1 week ago

We define TPU as a plugin and set the client create options in https://github.com/pytorch/xla/blob/91f5c8ad81a7d5149603d46a8ed9db6541432191/torch_xla/_internal/tpu.py#L352-L360. My understanding is that as long as you specify that option, the PJRT client will handle the rest.
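For reference, a hedged sketch of what passing the same option from a custom plugin might look like; the DevicePlugin base class, client_create_options() hook, and register_plugin() reflect torch_xla's plugin API as I understand it, while the NeuronPlugin name, library path, and the value of 4 are placeholders:

```python
from torch_xla.experimental import plugins

class NeuronPlugin(plugins.DevicePlugin):   # hypothetical plugin, for illustration only
    def library_path(self):
        return "/path/to/pjrt_plugin.so"    # placeholder path to the PJRT plugin library

    def client_create_options(self):
        return {
            # Cap on in-flight computations handled by the PJRT client itself,
            # analogous to the option the TPU plugin passes in the linked tpu.py.
            "max_inflight_computations": 4,  # placeholder value
        }

plugins.register_plugin("NEURON", NeuronPlugin())
```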

mcuiaws commented 1 week ago

So it sounds like for TPUs you want 32 in flight. Does that mean it's by design that you are breaking torch lazy_graph_executor's contract of "only one asynchronous operation can execute at the same time, on a given device"?

JackCaoG commented 1 week ago

"Only one asynchronous operation can execute at the same time" was an old design choice for the XRT runtime; back then, async execution was implemented at the torch_xla level. It was more of a design constraint than a design choice.

Ever since we moved to the PJRT runtime, which itself supports async transfer and async execution, it is better to let the runtime handle this kind of thing. We want to make sure the program is not tracing-bound, so it is better to unblock tracing and let it queue as many graphs as possible.

mcuiaws commented 4 days ago

Should we move the max-inflight logic into torch_xla's pjrt_computation_client.cc? PJRT's Execute APIs are asynchronous, and asynchronous APIs ideally should not block...

JackCaoG commented 4 days ago

hmm, torch_xla's execute and PJRT's execute are both async. The max_inflight_computations option in PJRT is the standard way to control the maximum number of async executions; XLA uses it for TPU, GPU, and CPU. In our case, when an async execution blocks inside PJRT, it won't release the device lock, so the main thread that does the tracing will also block, which is exactly the behavior you need, I think.
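As a conceptual picture of that backpressure, here is a small Python sketch with a semaphore standing in for whatever PJRT does internally; none of these names are actual torch_xla or PJRT APIs:

```python
import threading
from concurrent.futures import Future

MAX_INFLIGHT = 32                              # e.g. the TPU value mentioned above
inflight = threading.BoundedSemaphore(MAX_INFLIGHT)

def execute(dispatch_execution) -> Future:
    # The tracing thread holds the device lock when it reaches this point.
    # If MAX_INFLIGHT executions are already pending, acquire() blocks, the
    # device lock stays held, and tracing of further graphs pauses -- which is
    # exactly the throttling behavior described above.
    inflight.acquire()
    future: Future = dispatch_execution()      # stand-in for the async PJRT dispatch
    future.add_done_callback(lambda _: inflight.release())
    return future
```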