[runtime][hip] The driver does not manage its devices properly in a multi-device context #18077

Open sogartar opened 3 months ago

sogartar commented 3 months ago

There are a number of HIP functions that operate on the device currently selected for the calling thread, for example hipModuleLoadDataEx. We need to make sure the correct HIP device is current during IREE HAL HIP calls, and the HAL should restore the previously selected HIP device before returning control to the caller. From the user's point of view the IREE HAL should never change the selected device unexpectedly.
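
A minimal sketch of that save/restore behavior (hypothetical helper, not the actual HAL code; error handling condensed):

```c++
#include <hip/hip_runtime.h>

// Hypothetical sketch: load a module on `target_device` without leaving the
// calling thread's selected device changed (hipModuleLoadDataEx would follow
// the same pattern).
static hipError_t load_module_on_device(int target_device, const void* image,
                                        hipModule_t* out_module) {
  int previous_device = 0;
  hipError_t err = hipGetDevice(&previous_device);  // remember caller's device
  if (err != hipSuccess) return err;
  err = hipSetDevice(target_device);                // make the HAL device current
  if (err != hipSuccess) return err;
  err = hipModuleLoadData(out_module, image);       // TLS-dependent call
  hipError_t restore_err = hipSetDevice(previous_device);  // restore selection
  return err != hipSuccess ? err : restore_err;
}
```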

benvanik commented 3 months ago

This was/is an issue in CUDA as well and the last time a fix was attempted it ended in sadness. In CUDA changing the device would flush it and also synchronize, leading to switches taking anywhere from ~10ms to full seconds depending on whether there was any work in-flight. HIP may be better but I'm doubtful, given that it's mostly a flaw in the design of the API they inherited.

During initialization (module loading/etc) this isn't a critical issue, but it is important when managing execution. Most real CUDA multi-device usage follows the model of one OS thread per device that always has its respective device set, with per-device submissions managed via those threads. We already have the pending action queue and could just lean on this: during normal execution all queue work should happen via the action queue thread, and since each device has its own thread the device is naturally always set appropriately. If we wanted faster startup we could move module loading to the action queue as well, which would allow multiple devices to initialize modules at the same time.
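
For illustration, a rough sketch of that one-thread-per-device model (hypothetical `DeviceWorker` type; the real implementation would be the pending action queue, not this):

```c++
#include <hip/hip_runtime.h>

#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>

// Illustrative per-device worker: one OS thread per device sets its device
// exactly once, so any TLS-dependent HIP call made from the worker always
// sees the right device as current.
class DeviceWorker {
 public:
  explicit DeviceWorker(int device_id)
      : thread_([this, device_id]() {
          (void)hipSetDevice(device_id);  // set once for this thread's lifetime
          Loop();
        }) {}

  ~DeviceWorker() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      done_ = true;
    }
    cv_.notify_one();
    thread_.join();
  }

  // Enqueue work that may make TLS-dependent HIP calls (launches, event
  // creation, module loads, ...).
  void Submit(std::function<void()> action) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      actions_.push_back(std::move(action));
    }
    cv_.notify_one();
  }

 private:
  void Loop() {
    for (;;) {
      std::function<void()> action;
      {
        std::unique_lock<std::mutex> lock(mutex_);
        cv_.wait(lock, [&] { return done_ || !actions_.empty(); });
        if (actions_.empty()) return;  // done_ was set and the queue drained
        action = std::move(actions_.front());
        actions_.pop_front();
      }
      action();  // runs with this worker's device current
    }
  }

  std::mutex mutex_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> actions_;
  bool done_ = false;
  std::thread thread_;  // declared last so other members exist before it runs
};
```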

The big thing to work towards is avoiding TLS-dependent HIP API calls on the "main thread" (whatever arbitrary thread the user is calling IREE on and thus coming in through the IREE HAL APIs on): any such operation should either use cached information from startup (for queries/etc), go through the action queue, or be explicitly OK with potentially blocking behavior (ideally nothing falls into that last bucket, but it's fine to work towards that point).
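
For the "cached information from startup" case, a hedged sketch (hypothetical struct and field names): attributes are queried once while the device is being created and later HAL queries are served from the cache, so no HIP call happens on the user's thread:

```c++
#include <hip/hip_runtime.h>

// Hypothetical cache filled once while the HAL device is being created.
// Later queries on the user's thread read the struct directly instead of
// calling into HIP.
struct CachedDeviceInfo {
  hipDeviceProp_t props;
  int multiprocessor_count;
  int max_threads_per_block;
};

static hipError_t cache_device_info(int device_id, CachedDeviceInfo* out_info) {
  // These query APIs take an explicit device id, so they do not depend on the
  // thread-local "current" device and are safe to call during initialization.
  hipError_t err = hipGetDeviceProperties(&out_info->props, device_id);
  if (err != hipSuccess) return err;
  err = hipDeviceGetAttribute(&out_info->multiprocessor_count,
                              hipDeviceAttributeMultiprocessorCount, device_id);
  if (err != hipSuccess) return err;
  return hipDeviceGetAttribute(&out_info->max_threads_per_block,
                               hipDeviceAttributeMaxThreadsPerBlock, device_id);
}
```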

sogartar commented 3 months ago

Thank you for the heads-up on the device-change inefficiencies; that saves me from taking the wrong approach. I will make a list of the HIP and higher-level IREE ops that require setting the device to see what can be moved to the action queue thread.

sogartar commented 3 months ago

@AWoloszyn, you mentioned that you wanted to refactor the HIP runtime and possibly merge the common action queue logic between the CUDA and HIP drivers. I was wondering if it makes sense to do the refactoring first and then focus on this issue.

benvanik commented 3 months ago

Good point - if we use the queue for submissions then having it be common would also solve the CUDA bugs we have (or at least fix most of the issues and leave just a few whack-a-mole ones).

AWoloszyn commented 3 months ago

Just as a note: it's more than just the submissions that depend on TLS in the driver. hipEventCreate (and I think cuEventCreate as well) ties the event to the "current" device, so we have to be pretty careful about those calls too and make sure we pull them off the user thread as well.

benvanik commented 3 months ago

Good point - we should pool all of those we can anyway and only submit to the queue (or force a blocking switch) when the pool needs to grow. Async allocs also require some work, but it's not too different from what we need to do on other backends and we may be able to share things there (I've got a WIP PR for doing async allocs on CPU where we don't need to actually reserve memory at HAL API time and can instead put it in the queue). The good news is that all of this can happen incrementally so long as we know where we're going (and get the queue commoned and generalized): as long as we can get some real async traces from real programs and catch where we hit the major blocking ops, we can burn them down.
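
A hedged sketch of that pooling idea for events (hypothetical `EventPool` type, not the actual HAL code): events are created in batches with the right device current, e.g. on the per-device worker thread, and the hot path only pops from a free list, so no TLS-dependent hipEventCreate call happens on the user thread:

```c++
#include <hip/hip_runtime.h>

#include <mutex>
#include <vector>

// Illustrative event pool: hipEventCreate ties the event to the current
// device, so creation is confined to Grow(), which is expected to run on the
// per-device worker thread (or after an explicit blocking device switch).
class EventPool {
 public:
  // Must be called with the pool's device current (e.g. from its worker).
  hipError_t Grow(size_t count) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (size_t i = 0; i < count; ++i) {
      hipEvent_t event = nullptr;
      hipError_t err = hipEventCreate(&event);
      if (err != hipSuccess) return err;
      free_events_.push_back(event);
    }
    return hipSuccess;
  }

  // Safe from any thread: no HIP call is made when the free list is non-empty.
  bool TryAcquire(hipEvent_t* out_event) {
    std::lock_guard<std::mutex> lock(mutex_);
    if (free_events_.empty()) return false;  // caller must ask the worker to Grow()
    *out_event = free_events_.back();
    free_events_.pop_back();
    return true;
  }

  void Release(hipEvent_t event) {
    std::lock_guard<std::mutex> lock(mutex_);
    free_events_.push_back(event);
  }

 private:
  std::mutex mutex_;
  std::vector<hipEvent_t> free_events_;
};
```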

benvanik commented 3 months ago

Oh, for module loading we really need to fix the compiler to link executables together. Ideally we should have only one HAL executable per device (though there may be more), with multiple HIP/CUDA modules (or a library, etc) embedded within it. That makes offloading to a worker thread even less costly, as we'd only be sending one compilation request to the worker thread instead of one per dispatch as we would today.

sogartar commented 3 months ago

It turns out that finding out which CUDA/HIP API functions require the current device to be set is not as straightforward as reading the docs. Even the CUDA docs are missing this information for a lot of functions, so I am working it out through experimentation now.

Regardless of the per-function requirements, HIP must have requirements at least as relaxed as CUDA's; my assumption is that if this were not the case, hipifying would not be straightforward. We can focus on maintaining a unified implementation as much as possible. This may be a pessimization for HIP where it has less strict requirements for some functions, since we may then end up going through the device worker thread unnecessarily and adding extra latency. For example, I have checked that launching a kernel in HIP does not require the current thread's device to be the same as the one associated with the stream we are launching on, whereas according to the CUDA C++ Programming Guide, in CUDA you can't launch kernels on any device other than the current one.

We would like to schedule functions on the device thread with coarse granularity and make as few round trips to the device thread and back as possible. I think in a lot of cases the hand-off to the device thread would happen right after the IREE HAL API vtable dispatch; some of those entry points make multiple CUDA/HIP API calls to complete.
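
A minimal sketch of the kind of check described above (illustrative only; assumes at least two devices): the stream is created while device 1 is current, then device 0 is made current before launching on that stream:

```c++
#include <hip/hip_runtime.h>

#include <cstdio>

__global__ void noop_kernel() {}

int main() {
  // Illustrative check: the stream belongs to device 1, but device 0 is the
  // thread's current device at launch time.
  hipStream_t stream_on_device1 = nullptr;
  (void)hipSetDevice(1);
  (void)hipStreamCreate(&stream_on_device1);

  (void)hipSetDevice(0);
  hipLaunchKernelGGL(noop_kernel, dim3(1), dim3(1), 0, stream_on_device1);

  // Any launch error surfaces when we synchronize the stream.
  hipError_t err = hipStreamSynchronize(stream_on_device1);
  std::printf("launch on non-current device: %s\n", hipGetErrorString(err));
  return err == hipSuccess ? 0 : 1;
}
```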

sogartar commented 3 months ago

I decided not to rely on guesswork and raised questions for both HIP and CUDA. Even if calls don't fail right now, they may start failing in subsequent releases.