RWTH-ACS / cricket

cricket is a virtualization solution for GPUs
MIT License

Can this project support pytorch / tensorflow? #6

Open YixinSong-e opened 2 years ago

YixinSong-e commented 2 years ago

Reading through the code, I found that the project currently cannot support GPU virtualization for Python, since the application has to be a CUDA binary. Is that correct?

n-eiling commented 2 years ago

pytorch is not supported because pytorch loads the CUDA code as a shared object. I started implementing support for this, but it is not finished yet.

YixinSong-e commented 2 years ago

> pytorch is not supported because pytorch loads the CUDA code as a shared object. I started implementing support for this, but it is not finished yet.

Do you have any idea? I've been experimenting with this recently.

n-eiling commented 2 years ago

Once a shared object with CUDA code is loaded, it calls this function (or something similar): https://github.com/RWTH-ACS/cricket/blob/c0dec6241936958f6ddb55d8e5d33dfe8313efd6/cpu/cpu-client.c#L215 We cannot simply forward this call to the server, because the shared object is not loaded in the server. So, before these calls, we need to send the shared object to the server, load it there, and then forward the `__cudaRegisterFunction` calls. It's probably not that much work, but I am focusing on something different right now.

Hope this answers your questions.
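As a rough illustration of the forwarding described above, a client-side shim could look like the sketch below. `rpc_register_function()` and its signature are hypothetical placeholders for cricket's actual transport, and the parameter types of the hidden call are simplified (the real prototype uses `uint3 *`/`dim3 *` pointers).

```c
/* Sketch only: shadow the hidden registration entry point that
 * nvcc-generated host code calls when a CUDA binary or shared object
 * is loaded, and forward the registration to the server. */
void rpc_register_function(const void *host_stub,
                           const char *device_name);   /* hypothetical */

void __cudaRegisterFunction(void **fatCubinHandle, const char *hostFun,
                            char *deviceFun, const char *deviceName,
                            int thread_limit, void *tid, void *bid,
                            void *bDim, void *gDim, int *wSize)
{
    /* hostFun is the address of the host-side kernel stub; the server
     * needs it as a key so later launch calls can be matched to the
     * right kernel. The fat binary itself must already be on the server. */
    rpc_register_function((const void *)hostFun, deviceName);
}
```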

YixinSong-e commented 2 years ago

> Once a shared object with CUDA code is loaded, it calls this function (or something similar):
>
> https://github.com/RWTH-ACS/cricket/blob/c0dec6241936958f6ddb55d8e5d33dfe8313efd6/cpu/cpu-client.c#L215
>
> We cannot simply forward this call to the server, because the shared object is not loaded in the server. So, before these calls, we need to send the shared object to the server, load it there, and then forward the `__cudaRegisterFunction` calls. It's probably not that much work, but I am focusing on something different right now. Hope this answers your questions.

Yes, I did find this API, and I have tested the GVirtuS repo, which can send some versions of the shared object. However, not all cubin files are in the same format. The GVirtuS project parses the ELF of the fatcubin files to obtain the parameter characteristics of each CUDA kernel, but ELF files produced by older compiler versions cannot be parsed properly.
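For context, the ELF-level lookup being discussed amounts to walking a cubin's section headers until `.nv.info` turns up; the kernel parameter records (EIATTR_KPARAM_INFO entries) live in that section. A minimal sketch, assuming a well-formed 64-bit ELF image already in memory and omitting all validation:

```c
/* Sketch: locate the .nv.info section inside a cubin image in memory.
 * Real parsers (GVirtuS, cricket's cpu-elf2.c) then decode the EIATTR_*
 * records in this section to recover kernel parameter sizes/offsets. */
#include <elf.h>
#include <stddef.h>
#include <string.h>

const void *find_nv_info(const void *image, size_t *size_out)
{
    const Elf64_Ehdr *eh = image;
    const Elf64_Shdr *sh =
        (const Elf64_Shdr *)((const char *)image + eh->e_shoff);
    const char *names = (const char *)image + sh[eh->e_shstrndx].sh_offset;

    for (int i = 0; i < eh->e_shnum; i++) {
        if (strcmp(names + sh[i].sh_name, ".nv.info") == 0) {
            *size_out = sh[i].sh_size;
            return (const char *)image + sh[i].sh_offset;
        }
    }
    return NULL;   /* e.g. older toolchains with a different layout */
}
```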

hudengjunai commented 1 year ago

I found the GVirtuS project. Does GVirtuS support pytorch or tensorflow in applications that load CUDA code from a shared library?

xial-thu commented 1 year ago

> I found the GVirtuS project. Does GVirtuS support pytorch or tensorflow in applications that load CUDA code from a shared library?

No. Hijacking CUDA's internal APIs requires a huge amount of reverse engineering, and you'll waste your life dealing with the fatbinary format. It is not as simple as GVirtuS makes it look. I would suggest hijacking the pytorch API for remote execution instead.

nravic commented 1 year ago

> It's probably not that much work, but I am focusing on something different right now.

@n-eiling could you elaborate on what needs to be done? I'm interested in using the project for some orchestration of tensorflow/pytorch and would be happy to contribute to getting this done.

n-eiling commented 1 year ago

Any contribution is highly welcome! A while back I looked into building support for pytorch. The issue is that in pytorch the CUDA kernels are located in shared objects that are loaded at runtime using dlopen/dlsym. For shared objects, CUDA inserts some initialization code that registers the new kernels with the CUDA APIs. Without running this initialization code we cannot call the kernels using cuLaunchKernel. If I remember correctly, I already tracked the initialization down to `__cudaRegisterFunction`.

So to make pytorch work with cricket, we would need to detect the loading of CUDA shared objects on the client side, then copy each shared object to the server, load it there as well, and register the kernels.
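A rough sketch of that detection step, assuming a hypothetical `send_so_to_server()` transport call; a real implementation would additionally have to replay the registration calls the object makes on load:

```c
/* Sketch: interpose dlopen (e.g. via LD_PRELOAD), check whether the object
 * embeds CUDA code, and ship it to the server before the local copy runs
 * its registration constructors. send_so_to_server() is hypothetical. */
#define _GNU_SOURCE
#include <dlfcn.h>
#include <fcntl.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

void send_so_to_server(const char *path);   /* hypothetical transport */

/* Crude heuristic: look for the fatbin section name anywhere in the file.
 * A real implementation would walk the ELF section headers instead. */
static int has_cuda_code(const char *path)
{
    struct stat st;
    int fd = open(path, O_RDONLY);
    if (fd < 0)
        return 0;
    if (fstat(fd, &st) < 0) { close(fd); return 0; }
    void *map = mmap(NULL, st.st_size, PROT_READ, MAP_PRIVATE, fd, 0);
    close(fd);
    if (map == MAP_FAILED)
        return 0;
    int found = memmem(map, st.st_size, ".nv_fatbin", 10) != NULL;
    munmap(map, st.st_size);
    return found;
}

void *dlopen(const char *filename, int flags)
{
    void *(*real_dlopen)(const char *, int) =
        (void *(*)(const char *, int))dlsym(RTLD_NEXT, "dlopen");
    if (filename && has_cuda_code(filename))
        send_so_to_server(filename);    /* load + register server-side */
    return real_dlopen(filename, flags);
}
```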

nravic commented 1 year ago

Makes sense to me, thanks! I'll give this a try and come back with questions.

Will have to reread your paper in more detail :)

jin-zhengnan commented 1 year ago

> Once a shared object with CUDA code is loaded, it calls this function (or something similar):
>
> https://github.com/RWTH-ACS/cricket/blob/c0dec6241936958f6ddb55d8e5d33dfe8313efd6/cpu/cpu-client.c#L215
>
> We cannot simply forward this call to the server, because the shared object is not loaded in the server. So, before these calls, we need to send the shared object to the server, load it there, and then forward the `__cudaRegisterFunction` calls. It's probably not that much work, but I am focusing on something different right now. Hope this answers your questions.

I have been following your project for a long time and have some questions that I would like to ask. Please help me answer them when you have time; thank you very much. CUDA has some hidden functions, for example `__cudaRegisterFatBinary`, `__cudaRegisterFatBinaryEnd`, `__cudaPushCallConfiguration`, `__cudaPopCallConfiguration`, and so on. Why not directly call these hidden functions on the server after the client hijacks them, instead of converting them through other methods such as cuModuleLoadData? @n-eiling

wzhao18 commented 1 year ago

Hi @jin-zhengnan, I recently experimented with forwarding `__cudaRegisterFatBinary`, `__cudaRegisterFatBinaryEnd`, `__cudaPushCallConfiguration`, and `__cudaPopCallConfiguration` calls from client applications to the server, and it is doable. You just have to intercept those calls and store the necessary data, such as the cubin data, and then repeat the same procedure on the server.
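A sketch of what that client-side capture can look like. The wrapper layout below matches CUDA's `__fatBinC_Wrapper_t` as commonly described in reverse-engineering writeups (magic `0x466243b1`), `forward_fatbin_to_server()` is a hypothetical transport call, and resolving the real entry point is left to an init step:

```c
#include <stddef.h>

/* Layout of the wrapper nvcc passes to __cudaRegisterFatBinary. */
typedef struct {
    int magic;                           /* 0x466243b1 */
    int version;
    const unsigned long long *data;      /* the actual fatbin blob */
    void *filename_or_fatbins;
} fat_bin_wrapper_t;

void forward_fatbin_to_server(const void *blob);        /* hypothetical */
void **(*real_register_fat_binary)(void *);  /* resolved via dlsym at init */

void **__cudaRegisterFatBinary(void *fatCubin)
{
    fat_bin_wrapper_t *w = fatCubin;
    forward_fatbin_to_server(w->data);   /* replay the call server-side */
    return real_register_fat_binary(fatCubin);  /* keep local state intact */
}
```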

n-eiling commented 1 year ago

I also experimented with using these functions. These hidden functions are part of the runtime API, while I am using the cuModule APIs from the driver API. This is because I also use the driver API to launch kernels (cuLaunchKernel), and when I register kernels using the hidden APIs I do not easily get the cuModule handle required for this.
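For reference, the documented cuModule path referred to here looks roughly like this on the server side; `load_kernel()` and `launch()` are illustrative helper names, not cricket's, and error handling is elided:

```c
#include <cuda.h>

/* Register a kernel from a binary image the client sent over. */
CUfunction load_kernel(const void *image, const char *mangled_name)
{
    CUmodule mod;
    CUfunction fn;
    cuModuleLoadData(&mod, image);               /* image from the client */
    cuModuleGetFunction(&fn, mod, mangled_name); /* lookup by name */
    return fn;
}

/* Launch with a 1x1x1 grid and block, no shared memory, default stream. */
void launch(CUfunction fn, void **args)
{
    cuLaunchKernel(fn, 1, 1, 1, 1, 1, 1, 0, NULL, args, NULL);
}
```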

The hidden APIs seem to do basically the same as the cuModule APIs, except that they register the host kernel stubs. These stubs are the functions you call when launching a kernel through the runtime API (e.g., kernel<<<1, 1>>>()). We can't really intercept the stubs themselves, so Cricket instead intercepts the cuLaunchKernel call the stubs actually make.
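To make the stub mechanism concrete: nvcc lowers a `mykernel<<<grid, block>>>(42)` launch into roughly the host code below (simplified; the prototypes of the hidden configuration calls are as commonly reverse-engineered and may vary across CUDA versions). This is also why intercepting the launch call catches triple-bracket launches.

```c
#include <cuda_runtime.h>

/* Hidden runtime-API configuration calls used by nvcc-generated code;
 * not part of the public documented API. */
unsigned __cudaPushCallConfiguration(dim3 gridDim, dim3 blockDim,
                                     size_t sharedMem, void *stream);
cudaError_t __cudaPopCallConfiguration(dim3 *gridDim, dim3 *blockDim,
                                       size_t *sharedMem, void *stream);

/* The generated host stub: its *address* is what identifies the kernel. */
void mykernel(int arg)
{
    dim3 grid, block;
    size_t shmem;
    cudaStream_t stream;
    void *args[] = { &arg };
    __cudaPopCallConfiguration(&grid, &block, &shmem, &stream);
    cudaLaunchKernel((const void *)mykernel, grid, block, args,
                     shmem, stream);
}

/* The call site mykernel<<<grid, block>>>(42) becomes roughly: */
void launch_site(dim3 grid, dim3 block)
{
    if (__cudaPushCallConfiguration(grid, block, 0, NULL) == 0)
        mykernel(42);
}
```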

I think working at the driver API level as much as possible is desirable anyway, as it reduces the server-side code, so that in a virtualized environment most code executes in the VM or container. Also, the cuModule API is actually documented and hopefully more stable.

A few years back I also tried to implement Cricket entirely at the driver API level, so that the runtime API would also run on the client side. However, there are a lot more hidden functions called between the two, which don't even show up in any public header. With more investigation this might be possible, though...

wzhao18 commented 1 year ago

hi @n-eiling, thanks for responding. Sorry, I only found your repo a few days ago and haven't checked your implementation yet. I plan to read your paper soon.

I have a question regarding your answer. What do you mean by "runtime api (e.g., kernel<<<1, 1>>>()). We can't really intercept these"? When I check the compilation, it seems all kernel<<<1, 1>>> calls are in fact macros that launch the kernel through cudaLaunchKernel. In my experiments, all kernel launches could be intercepted at cudaLaunchKernel; in fact, I don't know when cuLaunchKernel is used at all. If you could provide some insight into the difference between cudaLaunchKernel and cuLaunchKernel, and into what the cuModule APIs are for, that would be very helpful.

Thanks in advance!

n-eiling commented 1 year ago

Ah okay. So cudaLaunchKernel requires the kernel stub as its first parameter, and these stubs are not available on the server. We could register something using `__cudaRegisterFunction`, but it is not as simple as copying a bunch of memory to the server and calling `__cudaRegisterFunction` on it. There are a lot of data structures involved in launching a kernel. I imagine CUDA does something similar to what Cricket does (i.e., creating a list of available kernels and storing some useful information about them), because to launch a kernel we need to know what architecture the kernel is compiled for, what parameters it has and at which offsets, whether it is compressed, etc. I am also pretty sure that the runtime API actually uses the driver API to implement this.
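A sketch of the kind of bookkeeping being described; all field names are illustrative, not cricket's actual structures:

```c
#include <stddef.h>
#include <cuda.h>

/* Illustrative only: per-kernel metadata a forwarding layer has to
 * reconstruct so it can marshal the void **args of a launch correctly. */
struct kernel_info {
    const void *client_stub;      /* key: stub address the client passes   */
    char       *mangled_name;     /* for cuModuleGetFunction on the server */
    CUfunction  func;             /* resolved server-side handle           */
    int         param_cnt;
    size_t      param_size[32];   /* decoded from .nv.info EIATTR records  */
    size_t      param_offset[32]; /* offsets into the argument buffer      */
};
```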

wzhao18 commented 1 year ago

@n-eiling, thanks for the explanation. I will check out how Cricket achieves it. My intention is to natively support applications like pytorch, and I noticed that they call cudaLaunchKernel; that's why I decided to use `__cudaRegisterFunction` to register the kernel stub on the server. I was able to retrieve all the data structures necessary to launch a kernel on the server.

"I am also pretty sure that the Runtime API is actually also using the driver API to implement this ." For this, I am not sure this is the case, or at least not necessarily using the public driver API. Because I tried to intercept both runtime and driver API calls, it seems that the runtime API may not call a driver API call. I haven't checked this carefully. I may be wrong.

n-eiling commented 1 year ago

Ah, cool. Do you have some code you can share? I would be interested in seeing whether this has any advantages over my approach. I gave up when I found I had to rebuild the CUfunction and CUmodule structures manually when using the hidden functions.

Have you seen #15? I am almost finished with pytorch support for Cricket. If you compile pytorch without kernel compression, it already works. I haven't tested it in much depth, though.

wzhao18 commented 1 year ago

Hi @n-eiling, sorry, I am not able to share the code yet because it is part of a WIP project. But the basic flow is as follows:

At the client end, I intercept three functions: `__cudaRegisterFatBinary`, `__cudaRegisterFunction`, and `__cudaRegisterFatBinaryEnd`. I pass the data intercepted from these calls to the server and repeat the exact same process there. Then the client can forward the cudaLaunchKernel call and execute it at the server end, as sketched below.
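A minimal sketch of the server-side replay under these assumptions; the hidden-API prototypes are declared locally because no public header exposes them (parameter types simplified), and the mapping from the client's stub address to the server's is only hinted at in the comments:

```c
#include <stddef.h>

/* Hidden registration functions, as commonly reverse-engineered. */
extern void **__cudaRegisterFatBinary(void *fatCubin);
extern void   __cudaRegisterFatBinaryEnd(void **handle);
extern void   __cudaRegisterFunction(void **handle, const char *hostFun,
                                     char *deviceFun, const char *deviceName,
                                     int thread_limit, void *tid, void *bid,
                                     void *bDim, void *gDim, int *wSize);

/* Replay the captured registration so a forwarded launch finds a kernel.
 * The server would also record kernel_name -> local stub so the client's
 * stub address can be translated on each launch. */
void replay_registration(void *fatbin_wrapper, char *kernel_name)
{
    void **handle = __cudaRegisterFatBinary(fatbin_wrapper);
    __cudaRegisterFunction(handle, kernel_name, kernel_name, kernel_name,
                           -1, NULL, NULL, NULL, NULL, NULL);
    __cudaRegisterFatBinaryEnd(handle);
}
```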

I have only tested this with a basic example, not pytorch yet, so I don't know whether it will hit issues later.

Tlhaoge commented 1 year ago

Hi, I have some problems when running cricket with pytorch. I pulled the latest code and built pytorch locally with the changes mentioned in the docs. My CUDA is 11.2 and cuDNN is 8.9.2 on a Tesla P4, but I get this problem:

server:

```
+08:01:00.423212 WARNING: duplicate resource! The first resource will be overwritten in resource-mg.c:145
+08:01:00.445168 WARNING: duplicate resource! The first resource will be overwritten in resource-mg.c:145
+08:01:00.445403 WARNING: duplicate resource! The first resource will be overwritten in resource-mg.c:145
+08:01:00.447247 WARNING: duplicate resource! The first resource will be overwritten in resource-mg.c:145
+08:01:00.448076 WARNING: duplicate resource! The first resource will be overwritten in resource-mg.c:145
+08:01:07.164339 ERROR: cuda_device_prop_result size mismatch in cpu-server-runtime.c:367
+08:02:22.370950 INFO: RPC deinit requested.
+08:08:54.324012 INFO: have a nice day!
```

client:

```
+08:00:36.417392 WARNING: could not find .nv.info section. This means this binary does not contain any kernels. in cpu-elf2.c:922
+08:00:36.418684 WARNING: could not find .nv.info section. This means this binary does not contain any kernels. in cpu-elf2.c:922
+08:00:36.420058 WARNING: could not find .nv.info section. This means this binary does not contain any kernels. in cpu-elf2.c:922
call failed: RPC: Timed out
call failed: RPC: Timed out
call failed: RPC: Timed out
+08:02:01.851255 ERROR: something went wrong in cpu-client-runtime.c:444
Traceback (most recent call last):
  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 242, in _lazy_init
    queued_call()
  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 125, in _check_capability
    capability = get_device_capability(d)
  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 357, in get_device_capability
    prop = get_device_properties(device)
  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 375, in get_device_properties
    return _get_device_properties(device)  # type: ignore[name-defined]
RuntimeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/lwh/cricket/tests/test_apps/pytorch_minimal.py", line 39, in <module>
    x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 246, in _lazy_init
    raise DeferredCudaCallError(msg) from e
torch.cuda.DeferredCudaCallError: CUDA call failed lazily at initialization with error:

CUDA call was originally invoked at:

['  File "/home/lwh/cricket/tests/test_apps/pytorch_minimal.py", line 31, in <module>\n    import torch\n', '  File "<frozen importlib._bootstrap>", line 991, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 843, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed\n', '  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/__init__.py", line 798, in <module>\n    _C._initExtension(manager_path())\n', '  File "<frozen importlib._bootstrap>", line 991, in _find_and_load\n', '  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked\n', '  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked\n', '  File "<frozen importlib._bootstrap_external>", line 843, in exec_module\n', '  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed\n', '  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 179, in <module>\n    _lazy_call(_check_capability)\n', '  File "/root/anaconda3/envs/py3.8/lib/python3.8/site-packages/torch/cuda/__init__.py", line 177, in _lazy_call\n    _queued_calls.append((callable, traceback.format_stack()))\n']
+08:02:27.007890 ERROR: call failed. in cpu-client.c:213
+08:02:27.012036 INFO: api-call-cnt: 14
+08:02:27.012051 INFO: memcpy-cnt: 0
```

Is my CUDA version wrong, or is it something else?

n-eiling commented 1 year ago

I currently use the docker container from pytorch commit 09d093b47b85b4958a7307249769fc0ee8658af9, which uses CUDA 11.7. There was some change to cudaGetDeviceProp somewhere in CUDA 11; due to some weird API design in CUDA, there seems to be a size mismatch of the returned struct in some versions. Please try with a newer one. Note that cudnnBackend is not working properly yet.

mehrdad-yousefi commented 1 year ago

@n-eiling I'm very interested in doing some experiments using cricket with PyTorch. Where can I find the documentation to build cricket with PyTorch support?

mehrdad-yousefi commented 1 year ago

@n-eiling Nevermind, I found it here: https://github.com/RWTH-ACS/cricket/blob/bedf45fd8301810009bb5df481262dcb52cf59f8/docs/pytorch.md