mcuiaws opened this issue 1 week ago
We controlled it with `XLA_TPU_MAX_INFLIGHT_COMPUTATIONS`, but I guess that's a TPU-specific flag. We can do something similar for Neuron.

How is `XLA_TPU_MAX_INFLIGHT_COMPUTATIONS` implemented? Do you block inside `PJRT_LoadedExecutable_Execute()` if the client queues up too many executions?
We define TPU as a plugin and set the client create options in https://github.com/pytorch/xla/blob/91f5c8ad81a7d5149603d46a8ed9db6541432191/torch_xla/_internal/tpu.py#L352-L360. My understanding is that as long as you specify that option, the PJRT client will handle the rest.
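For reference, those linked lines essentially have the TPU plugin hand a small dict of create options to the PJRT client. A paraphrased sketch (assuming the `DevicePlugin.client_create_options()` hook from `torch_xla.experimental.plugins`; other option keys in the real file are omitted):

```python
# Paraphrased sketch of the linked tpu.py lines: the TPU plugin caps
# in-flight executions via a PJRT client create option, with the limit
# read from XLA_TPU_MAX_INFLIGHT_COMPUTATIONS (defaulting to 32).
import torch_xla.utils.utils as xu
from torch_xla.experimental import plugins


class TpuPlugin(plugins.DevicePlugin):

  def client_create_options(self) -> dict:
    return {
        'max_inflight_computations':
            xu.getenv_as('XLA_TPU_MAX_INFLIGHT_COMPUTATIONS', int, 32),
    }
```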
So it sounds like for TPUs you want 32 computations in flight. Does that mean it's by design that you are breaking torch lazy_graph_executor's contract of "only one asynchronous operation can execute at the same time, on a given device"?
"only one asynchronous operation can execute at the same time" was an old design choice for the XRT runtime. Back then, async execution was implemented at the torch_xla level, so it was more of a design constraint than a design choice. Ever since we moved to the PJRT runtime, the runtime itself supports async transfers and async execution, so it is better to let the runtime handle this kind of thing. We want to make sure the program is not tracing-bound, so it is better to keep the tracing thread unblocked and let it trace as many graphs ahead as possible.
Should we move the max-inflight logic into torch_xla's `pjrt_computation_client.cc`? PJRT's Execute APIs are asynchronous, and asynchronous APIs should not block, ideally...
Hmm, torch_xla's execute and PJRT's execute are both async. The `max_inflight_computations` option in PJRT is the standard way to control the maximum number of async executions; XLA uses it for TPU, GPU and CPU. In our case, when an async execution blocks inside PJRT, it won't release the device lock, so the main thread that does the tracing will also block, which is exactly the behavior you need, I think.
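For Neuron, the same option could be passed when your plugin's PJRT client is created. A minimal sketch, assuming the `torch_xla.experimental.plugins` interface (`DevicePlugin.client_create_options()` and `register_plugin()`); the library path, env var name, and default limit below are made up for illustration:

```python
# Hedged sketch: a Neuron device plugin that caps in-flight executions
# the same way the TPU plugin does. The paths and env var names here are
# illustrative, not existing Neuron or torch_xla identifiers.
import torch_xla.utils.utils as xu
from torch_xla.experimental import plugins


class NeuronPlugin(plugins.DevicePlugin):

  def library_path(self) -> str:
    # Hypothetical path to the Neuron PJRT plugin library.
    return '/opt/aws/neuron/lib/libneuronpjrt.so'

  def client_create_options(self) -> dict:
    return {
        # Block further dispatches once this many executions are queued.
        'max_inflight_computations':
            xu.getenv_as('NEURON_MAX_INFLIGHT_COMPUTATIONS', int, 4),
    }


plugins.register_plugin('NEURON', NeuronPlugin())
```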
🐛 Bug
Here at AWS we have a single PJRT device plugin for both PyTorch and JAX, and recently we've made improvements to our device plugin to make it work better with JAX. I.e., `PJRT_LoadedExecutable_Execute()` is now fully asynchronous: we queue up an execution and return immediately, expecting the caller to wait on the `returned_future`, whereas before, execution was synchronous and had completed by the time `PJRT_LoadedExecutable_Execute()` returned.

As soon as we switched to the new implementation, we noticed that torch_xla now queues up as many executions as it can, with no throttling in either PJRT or torch_xla, which makes it easy to exhaust device memory. There no longer appear to be any internal throttling mechanisms, only an explicit one that has to be triggered by user code: `xm.wait_device_ops()`, which calls down to `WaitDeviceOps()`.
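Today the only way we've found to keep memory bounded is for the user to insert those calls manually, along these lines (illustrative only; the interval is a guess the user has to make):

```python
# Illustrative workaround: periodically drain the device so the host
# cannot queue an unbounded number of executions ahead of it. The
# 10-step interval is arbitrary; the user has to guess a value that
# keeps device memory bounded.
import torch_xla.core.xla_model as xm


def train_epoch(loader, model, loss_fn, optimizer):
  for step, (data, target) in enumerate(loader):
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    xm.mark_step()           # queues an async execution, returns immediately
    if step % 10 == 0:
      xm.wait_device_ops()   # stop-the-world: waits for the pipeline to drain
```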
However, `WaitDeviceOps()` is a heavy hammer because it pauses the world until the entire pipeline is drained. Ideally we do not want to rely on this mechanism for throttling, and we do not want the user to have to guess when to insert these calls to avoid running out of memory. Some sensible internal throttling mechanism is needed.

The main issue here is that `pjrt_computation_client.cc` does not await the `returned_future` from PJRT; it simply throws it away. However, according to torch's lazy_graph_executor, "only one asynchronous operation can execute at the same time, on a given device." This is controlled by a device lock, which is supposed to be held for the entire duration of the asynchronous execution. Yet in torch_xla's `xla_graph_executor.cpp`, the device locks acquired by torch are released as soon as `ExecuteComputation()` returns, and `ExecuteComputation()` does not actually wait for the computation to complete. Therefore, torch lazy_graph_executor's throttling mechanism is defeated here.
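To make the ask concrete, what we are hoping for is some internal bound on the number of in-flight executions, conceptually something like the sketch below (names are illustrative, not actual torch_xla or PJRT APIs):

```python
# Framework-agnostic sketch of the throttling being discussed: a semaphore
# bounds how many executions can be queued at once, and the slot is only
# released when an execution actually finishes, not when it is dispatched.
import threading
from concurrent.futures import Future, ThreadPoolExecutor


class ThrottledExecutor:

  def __init__(self, max_inflight: int = 32):
    self._slots = threading.Semaphore(max_inflight)
    self._pool = ThreadPoolExecutor()

  def execute(self, computation, *args) -> Future:
    # Blocks the caller (here, the tracing thread) once `max_inflight`
    # executions are already outstanding.
    self._slots.acquire()
    future = self._pool.submit(computation, *args)
    # Free the slot only on completion of the execution.
    future.add_done_callback(lambda _: self._slots.release())
    return future
```

Either PJRT's `max_inflight_computations` option or awaiting the `returned_future` before releasing the device lock would give this behavior.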