iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/

Does IREE support native CUDA buffer VM invocation, or is there any chance to improve the performance of result torch tensor construction? #18811

Open LWenH opened 5 days ago

LWenH commented 5 days ago

Hello IREE developers, I've seen a similar question asked before (https://github.com/iree-org/iree/issues/11573#issuecomment-1371128261), but it seems to have gone without further discussion.

Suppose we have a common ResNet-50 inference script run through PyTorch/IREE, where the model and the input torch tensors are already on the CUDA device:

```python
import time
from transformers import AutoImageProcessor, ResNetModel
import torch
from datasets import load_dataset

dataset = load_dataset("huggingface/cats-image")
image = dataset["test"]["image"][0]
image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = ResNetModel.from_pretrained("microsoft/resnet-50")

device = "cuda"
model.to(device)

inputs = image_processor(image, return_tensors="pt")

# move the input tensors to the CUDA device
inputs = inputs.to(device)

compiled_model = torch.compile(model, backend="turbine_cuda")
nwarm = 1
n = 10

# warm up
for i in range(nwarm):
    o1 = compiled_model(**inputs)

# timed runs of the IREE-compiled model
t0 = time.time()
for i in range(n):
    o1 = compiled_model(**inputs)
t1 = time.time()
print(o1.last_hidden_state)
print(f"iree run: {t1 - t0} s")

For VM context invocation, the current default path used by turbine is to construct a DeviceArray and wrap the buffer into an iree_hal_buffer_t, which can introduce a redundant copy in AllocateBufferCopy (https://github.com/iree-org/iree/blob/05bbcf1385146d075829cd940a52bf06961614d0/runtime/bindings/python/hal.cc#L165). I did a little hack to use a lower-level API, iree_hal_cuda_buffer_wrap, rather than iree_hal_device_transfer_h2d, to wrap a CUDA device buffer directly and avoid the copy.

However, this process still looks very time-consuming. I did a simple Nsight profile of this path in turbine (https://github.com/iree-org/iree-turbine/blob/1aa05e04d333bd2485337fbbe83c1c84a7c5e7f5/iree/turbine/dynamo/executor.py#L93); the result shows that constructing the iree_hal_buffer_t inputs and unwrapping the iree_hal_buffer_t outputs back into torch tensors takes about as long as the VM context invoke itself:

[Image: Nsight profiling timeline of the turbine invocation]

I know that the design principle of the HAL layer is mainly to stay compatible with buffers from different device platforms and to make buffer management convenient, but is it possible to hack into IREE so that it supports a more native CUDA-buffer VM invocation path for better performance? Or are there other, better ways to improve the performance of result tensor construction?

Thanks.

benvanik commented 5 days ago

You want the buffer import API, which lets you wrap a CUDA device pointer in an iree_hal_buffer_t: https://github.com/iree-org/iree/blob/05bbcf1385146d075829cd940a52bf06961614d0/runtime/src/iree/hal/allocator.h#L379

stellaraccident commented 4 days ago

The way the torch.compile wrapper is doing this is old and no one has fixed it. The support for using DLPack to marshal the tensors is in there, but it is not wired up on the torch.compile side yet.
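For context, this is roughly what the DLPack hand-off looks like on the PyTorch side; it's only a minimal illustration of the zero-copy capsule exchange, not turbine's actual marshaling code, and the IREE side of the exchange is not shown:

```python
import torch
from torch.utils.dlpack import from_dlpack, to_dlpack

# A CUDA tensor whose storage we want to hand off without copying.
t = torch.randn(1, 3, 224, 224, device="cuda")

# Export: a PyCapsule carrying the device pointer, dtype, shape, and strides.
capsule = to_dlpack(t)

# Import: any DLPack-aware consumer can wrap the same device memory without
# a copy (here torch itself consumes the capsule again for demonstration).
t2 = from_dlpack(capsule)
assert t2.data_ptr() == t.data_ptr()
```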

LWenH commented 4 days ago

Yeah, I used an analogous buffer import API to wrap a CUDA device buffer in an iree_hal_buffer_t handle, but it's a CUDA-platform-specific API: https://github.com/iree-org/iree/blob/05bbcf1385146d075829cd940a52bf06961614d0/runtime/src/iree/hal/drivers/cuda/cuda_buffer.c#L37

However, judging from the profiling results, even this pure buffer import/wrap procedure still looks very time-consuming. For example, suppose we have an input list of 400 torch tensors; that means we have to perform 400 buffer import/wrap calls to build the variant list used as the VM invocation's input.

After the VM invocation, we also have to use DLPack to unwrap each iree_hal_buffer_view_t and marshal it back into an output torch tensor. If the output list has 450 entries, we have to run that marshaling procedure 450 times. Overall, the buffer import and tensor marshaling take about as long as the VM context invoke itself.

So my question is: are there any other, more 'torch tensor native' ways to hack this, so that IREE's VM can take PyTorch tensors as inputs and produce torch tensors as outputs directly, for better performance?

Or, since this torch.compile wrapper is an 'outdated' way to invoke IREE, are there other, more PyTorch-performance-friendly ways to run the inference script above?

Thanks.

stellaraccident commented 4 days ago

First off, hundreds of tensors for something like that isn't great. I guess that's just because torch.compile functionalizes everything. Not a great design for performance even if all of the latencies are eliminated.

I think we'll end up just writing a C extension properly at some point and doing some manner of fast bulk import. The device pointer import case can be fast pathed but taking that much stuff through so many layers of python marshaling is never going to be particularly good. When I was writing all of the dlpack marshal code, I was just cringing looking at all of the tolls that get exacted just to exchange device pointers. For casual use or a small number of things, it is ok. For hundreds, there's no way to make that good. Need to drop down to an optimized c implementation for that level of marshaling, I expect.

There are also lower level APIs for allocating a result slab as one backing allocation and then having that be used by iree. It's not super straightforward to get everything to line up, but if doing it over again (which we are going to have to do), I would have just started there.

Probably not the answer you're looking for.

stellaraccident commented 4 days ago

We could probably skip the dlpack stuff and just use data_ptr (https://pytorch.org/docs/stable/generated/torch.Tensor.data_ptr.html).
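For reference, data_ptr() together with the shape/dtype/size metadata already gives everything a pointer-level import would need. A rough sketch of gathering that for a batch of CUDA tensors is below; the bulk-import API itself does not exist yet and is only being proposed here:

```python
import torch

tensors = [torch.randn(1, 3, 224, 224, device="cuda") for _ in range(4)]

# For each (contiguous) tensor, collect the raw CUDA device pointer plus the
# metadata an importer would need to wrap the storage without a copy.
records = [
    (t.data_ptr(), t.numel() * t.element_size(), t.dtype, tuple(t.shape))
    for t in tensors
    if t.is_contiguous()
]
for ptr, nbytes, dtype, shape in records:
    print(hex(ptr), nbytes, dtype, shape)
```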

If doing that from Python, it still wouldn't be as fast as it could be in C, but we could probably add some optimized API to the IREE Python bindings to take a bunch of torch tensors in bulk and get them into IREE as fast as possible.

And then use the result buffer ABI to pass in one torch-allocated result buffer and have IREE write into that vs returning individual tensors. For torch.compile, we'd need to do a little work to emit some kind of result size calculation so we're not just guessing. A lot of that information is in the metadata but needs to be found and used carefully.
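On the torch side, the "one backing allocation" idea might look roughly like the sketch below: allocate a single result slab and carve per-result views out of it, with IREE then asked to write into that slab instead of returning separate buffers (the IREE output-parameter plumbing itself is not shown, and the result shapes/dtypes are made up for illustration):

```python
import torch

# Hypothetical result specs; in practice the compiler would need to emit
# (or bound) these sizes from the program metadata rather than guessing.
result_specs = [((1, 2048, 7, 7), torch.float32), ((1, 1000), torch.float32)]

def nbytes(shape, dtype):
    return torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()

# One backing allocation on the device for all results.
slab = torch.empty(sum(nbytes(s, d) for s, d in result_specs),
                   dtype=torch.uint8, device="cuda")

# Carve per-result views out of the slab; the goal is for IREE to write its
# outputs directly into this single buffer instead of returning m buffers.
views, offset = [], 0
for shape, dtype in result_specs:
    n = nbytes(shape, dtype)
    views.append(slab[offset:offset + n].view(dtype).view(shape))
    offset += n
```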

Then basically get that down to one call into IREE per invocation to handle all of it, vs n+m calls to handle the inputs and results separately. It still wouldn't be free but could be made as fast as possible.

LWenH commented 4 days ago

Sure, thank you for your patient reply; that's actually the answer I was looking for. BTW, could you tell me where I can find those lower-level APIs for allocating a result slab as one backing allocation, so that I can dig into the details?

LWenH commented 4 days ago

BTW, I'd like to ask when the IREE community expects to start (or complete) this VM input/output data-transfer optimization work. Can we expect an approximate time frame? I hope this doesn't cause any disturbance :-)

stellaraccident commented 3 days ago

This is rising in importance for some of the core contributors too, and I am trying to see where it can slot in on a roadmap. The issue is that this is a relatively tricky thing and only a few people know all of the pieces involved.

Let me try to refresh some state on the output-parameter idea. It requires generating the functions differently, which is something I always thought would be needed to really serve the torch.compile case well but never had the time to flesh out (the other reason for this is that it is needed to handle non-contiguous input layouts without a copy).
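As a small illustration of the non-contiguous point (plain PyTorch, not turbine code): a strided view cannot be described as one dense buffer, so today it has to be materialized with a copy before it can be handed off:

```python
import torch

x = torch.randn(8, 16, device="cuda")
xt = x.t()                     # transposed view: same storage, strided layout
print(xt.is_contiguous())      # False: not describable as one flat dense buffer

# Without layout/output-parameter support, the tensor must be materialized
# into a fresh dense allocation before hand-off.
xc = xt.contiguous()
print(xc.data_ptr() == x.data_ptr())  # False: a new allocation, i.e. a copy
```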

These are all the reasons why I say that the current torch.compile integration is a demo, and it was only ever written for CPU. Doing it right is a relatively advanced use of both the compiler and the runtime: not hard per se, but detailed.

I can't promise when, but I will at least try to write down a design for how this should work.

LWenH commented 1 day ago

Thank you, Stella. I look forward to seeing the community's progress on this soon, and I'll keep following along.