timofeev1995 opened this issue 2 years ago (status: Open)
@timofeev1995 I cannot reproduce this on A40 x 4 machines.
Can you try to build the latest ft_triton image and try it on other machines (like A100 and A40)?
It would also be great if you could share the config.ini that is generated when you convert the model.
@PerkzZheng, thank you for the reply!
I have the following config.ini:
```ini
model_name = gptj-6B
head_num = 16
size_per_head = 256
inter_size = 16384
num_layer = 28
rotary_embedding = 64
vocab_size = 50400
start_id = 50256
end_id = 50256
weight_data_type = fp32
```
> Can you try to build the latest ft_triton image and try it on other machines (like A100 and A40)?
Do you mean there could be some issue with the connection between the GPUs? I have the following setup:
```
nvidia-smi topo -m
        GPU0   GPU1   GPU2   GPU3   GPU4   CPU Affinity   NUMA Affinity
GPU0     X     PIX    NODE   NODE   NODE   0-15,32-47     0
GPU1    PIX     X     NODE   NODE   NODE   0-15,32-47     0
GPU2    NODE   NODE    X     PIX    PIX    0-15,32-47     0
GPU3    NODE   NODE   PIX     X     PIX    0-15,32-47     0
GPU4    NODE   NODE   PIX    PIX     X     0-15,32-47     0

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks
```
Thank you!
> Do you mean there could be some issue with the connection between the GPUs?
No, there is no communication between GPUs. I am just trying to understand if this happens only in certain circumstances.
Also, please add `--ipc=host` if you are using Docker containers.
Unfortunately, with `--ipc=host` added to `docker run`, the result is the same (nullptr for every instance except the main fastertransformer_0_0).
Okay. Have you tried building the image with the latest Triton? I will also try the GPT-J config.ini you shared.
I'm going to pull the latest Triton Server tag and try it on a GCP 2xA100 instance. Thank you!
Got it. Everything works as expected there. Could we try to figure out what the problem is in my initial setup? Thank you!
I am not quite sure. Have you tried to run the inference with the latest image on 3 RTX3090s?
I tried the latest Triton Server container (22.09) with the 3090s and got the same issue.
Here is what you could try to help us eliminate some potential factors:
I'm facing the same issue on 2xT4
@lakshaykc Hi, I have the same issue on two T4s. Have you solved this problem?
No, I haven't been able to fix that issue yet.
@DunZhang @lakshaykc Does it work if the two model instances are on the same GPU?
Yes, that works. Even though there is no communication between the two model instances or GPUs, it seems there is something related to cross-GPU communication, but I have no idea what that could be. Running two separate instances on GPU1 and GPU2, as @PerkzZheng mentioned above, works, but obviously that is not a clean solution.
I cannot reproduce this on 2xT4 machines, so maybe you can try adding the following code at libfastertransformer.cc#L1793:
```cpp
cudaPointerAttributes attributes;
cudaPointerGetAttributes(&attributes, output_buffer);
LOG_MESSAGE(
    TRITONSERVER_LOG_VERBOSE,
    (std::string(" Memory_pointer_type: ") + std::to_string(attributes.type)).c_str());
```
and see if the pointers are device pointers (attributes.type == 2, i.e. cudaMemoryTypeDevice).
Also, add `cudaDeviceSynchronize()` at that point to see if anything changes. It looks like there may be a race condition where the pointers haven't been allocated before being passed to `cudaMemcpyAsync`.
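For anyone following along, here is a minimal standalone sketch of that check (my own illustration, not the backend code; the buffer name is just a placeholder):

```cpp
// Standalone sketch: allocate a device buffer and check what
// cudaPointerGetAttributes reports for it. A device pointer should show
// attributes.type == cudaMemoryTypeDevice (2).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    void* output_buffer = nullptr;
    if (cudaMalloc(&output_buffer, 1 << 20) != cudaSuccess) {
        std::printf("cudaMalloc failed\n");
        return 1;
    }

    cudaPointerAttributes attributes;
    cudaError_t err = cudaPointerGetAttributes(&attributes, output_buffer);
    std::printf("Memory_pointer_type: %d (%s)\n",
                static_cast<int>(attributes.type), cudaGetErrorString(err));

    // In the real backend, a cudaDeviceSynchronize() right before the copy
    // helps rule out ordering/race issues, as suggested above.
    cudaDeviceSynchronize();

    cudaFree(output_buffer);
    return 0;
}
```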
@lakshaykc @DunZhang Give this branch a try and see if the errors still exist. Let me know if you find anything helpful.
Ok, thanks, I'll try it and let you know what happens.
Hi @PerkzZheng! I'm also facing the same issue, and my setup is similar to @timofeev1995's, with several 3090s. My topology is:
```
        GPU0   GPU1   GPU2   GPU3   GPU4   CPU Affinity   NUMA Affinity
GPU0     X     PIX    PIX    NODE   NODE   0-15,32-47     0
GPU1    PIX     X     PIX    NODE   NODE   0-15,32-47     0
GPU2    PIX    PIX     X     NODE   NODE   0-15,32-47     0
GPU3    NODE   NODE   NODE    X     PIX    0-15,32-47     0
GPU4    NODE   NODE   NODE   PIX     X     0-15,32-47     0
```
I've been debugging this behaviour for some time, and I found out that in my scenario the problem is in this `cudaMemcpyAsync`. While in debug mode I found that switching the GPU with `cudaSetDevice` fixed the problem, so I implemented a crude workaround that cycles through the GPUs until it manages to copy the data. I hope this helps you understand the bug better. I'll also check your bugfix ASAP and come back with the results.
I've tested your fix and unfortunately it did not resolve the problem for me. I double-checked that the code in the Docker image indeed contains your fix, but I still get the same errors. I hope my previous comment may help you.
Just in case: the error I'm getting is `pinned buffer: failed to perform CUDA copy: invalid argument`.
Thanks, @hawkeoni. You mean doing `cudaSetDevice` for each model instance? As far as I know, `cudaMemcpyAsync` doesn't need a `cudaSetDevice` before it, as the stream carries the device ID info. Anyway, thanks, and I will push a quick fix that you guys can validate.
We do have that for each model instance. Do you mean doing this before `responder.processTensor`? It might be the case that the model instance constructor and the forward pass run on different threads (which also means different device contexts).
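If that is the case, one thing to experiment with (purely a sketch under that assumption; `instance_device_id` is a hypothetical parameter for the GPU the instance was created on) is re-pinning the device on the thread that executes the forward/response path:

```cpp
#include <cuda_runtime.h>

// Hypothetical helper: the current CUDA device is a per-thread setting, so a
// thread that did not create the model instance may still point at device 0.
// Re-pinning it before any copies are issued would test the hypothesis above.
void EnsureDeviceForThisThread(int instance_device_id)
{
    int current = -1;
    cudaGetDevice(&current);
    if (current != instance_device_id) {
        cudaSetDevice(instance_device_id);
    }
}
```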
Hi @PerkzZheng, I apologize I couldn't answer you sooner. I've once again tested your fix on setups with 2 to 5 GPUs and it didn't work (I had my doubts, because I originally tried it on 5 GPUs).
> Thanks, @hawkeoni. You mean doing `cudaSetDevice` for each model instance? As far as I know, `cudaMemcpyAsync` doesn't need a `cudaSetDevice` before it, as the stream carries the device ID info. Anyway, thanks, and I will push a quick fix that you guys can validate.
I also believe that CUDA copy commands do not require setting the device. Moreover, I've written a simple script on my machine which just copies from different GPUs to the CPU, and it works no matter which device is active, so I'm not entirely sure what my fix does, but it works.
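For context, here is roughly what such a test looks like (my own sketch, not the exact script): it allocates a buffer on every GPU and copies each one back to the host with `cudaMemcpyAsync` while device 0 stays active.

```cpp
// Sketch: copy from every GPU to the host without switching the active device
// back to the source GPU first. On my machine this succeeds regardless of
// which device is current.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

int main() {
    int device_count = 0;
    cudaGetDeviceCount(&device_count);

    const size_t bytes = 1 << 20;
    std::vector<void*> device_buffers(device_count, nullptr);

    // Allocation does require selecting each device in turn.
    for (int d = 0; d < device_count; ++d) {
        cudaSetDevice(d);
        cudaMalloc(&device_buffers[d], bytes);
        cudaMemset(device_buffers[d], d, bytes);
    }

    // Keep device 0 active and copy each buffer back to the host.
    cudaSetDevice(0);
    std::vector<char> host(bytes);
    for (int d = 0; d < device_count; ++d) {
        cudaError_t err = cudaMemcpyAsync(host.data(), device_buffers[d], bytes,
                                          cudaMemcpyDeviceToHost, /*stream=*/0);
        cudaStreamSynchronize(0);
        std::printf("GPU %d -> host: %s\n", d, cudaGetErrorString(err));
    }

    for (int d = 0; d < device_count; ++d) {
        cudaSetDevice(d);
        cudaFree(device_buffers[d]);
    }
    return 0;
}
```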
> We do have that for each model instance. Do you mean doing this before `responder.processTensor`?

I've been debugging this problem, and I found out that the copy doesn't work even at that point.
> We do have that for each model instance. Do you mean doing this before `responder.processTensor`?

I'll try to give you the stack trace as best I can: this `cudaMemcpyAsync` fails with `cudaErrorInvalidValue` every time it runs on any GPU other than the first one.
I've made a small fix where I loop over all GPUs and retry the memcpy until it succeeds or I run out of GPUs, and that works. As you've said before, and as I've checked on my setup, CUDA copies work correctly without setting a device, and there is no DeviceToDevice communication, so I'm at a loss at the moment. Hope this helps.
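Purely for illustration, this is roughly the shape of that retry loop (a sketch of the workaround I described, not the actual patch; the argument names are placeholders):

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical sketch of the workaround: walk through the visible GPUs,
// switching the current device each time, until the async copy is accepted.
cudaError_t CopyToHostWithDeviceRetry(void* dst_host, const void* src_device,
                                      size_t bytes, cudaStream_t stream)
{
    int device_count = 0;
    cudaGetDeviceCount(&device_count);

    cudaError_t err = cudaErrorInvalidValue;
    for (int d = 0; d < device_count; ++d) {
        cudaSetDevice(d);
        err = cudaMemcpyAsync(dst_host, src_device, bytes,
                              cudaMemcpyDeviceToHost, stream);
        if (err == cudaSuccess) {
            break;  // the copy was accepted under this device context
        }
    }
    return err;
}
```

It is obviously a band-aid rather than a fix, but it does narrow the problem down to the device context under which the copy is issued.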
Can you try this branch: `fix/multi_instance`?
I have replaced my current branch with `fix/multi_instance`. The problem seems to be solved for now (my setup is 3x RTX3090). Note: you should adjust the dynamic batching queue delay to get full utilization of the GPUs.
@byshiue Hi! I'm sorry it took me so long to answer, but I finally got to check your fix, and it works on my setup (RTX3090)! Is it possible to merge it into the main branch?
@byshiue @PerkzZheng Hi! The fix from this comment works; are there any plans to merge it?
Hello. Thank you for your work and the framework!
My goal is to host N instances of GPT-J-6B on N graphics cards, i.e. N instances with one model on each GPU. My setup is 3x3090 (another host has 5x3090, but everything else is similar), a Docker container built according to the repository README, and a fastertransformer config that looks like this (for the 3-GPU setup):
I start Triton using the command:

```
CUDA_VISIBLE_DEVICES="0,1,2" /opt/tritonserver/bin/tritonserver --http-port 8282 --log-verbose 10 --model-repository=/workspace/triton-models
```
The issue I face can be described as follows: I see 3 instances running on GPUs 0, 1, and 2, but when I request inference, I get a correct generation result in only 1/3 of cases:

In the other 2/3 of cases (requests routed not to fastertransformer_0_0 but to fastertransformer_0_1 / fastertransformer_0_2) I get the following:
And accordingly, from the Triton HTTP client (the Python one):

```
tritonclient.utils.InferenceServerException: pinned buffer: failed to perform CUDA copy: invalid argument
```
Is there an error in my setup or pipeline? What is the correct way to set up N models on N GPUs?
Thank you in advance!