mpjlu opened this issue 1 year ago
There is the same problem for bloomz-560m with tensor parallel on 2 or 4 GPUs: there is always one GPU whose gap between kernels is very long, as in this picture. The red box marks the all_reduce end time. The gap between kernels on GPU0 is very long.
Hi @mpjlu,
I noticed a similar issue when benchmarking another model recently and came across your post. After delving deeper, it seems to me that the bottleneck actually lies in the kernel launch gaps across all GPUs (note the CUDA API row inside Nsight).
The reason the GPU1 execution gap appears small in the picture above seems to be that the tasks are already queued while waiting for the all_reduce to finish. You can double-check the correlation_id to find the corresponding kernel launch timestamp.
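For reference, here is a minimal sketch (not from the original report) of checking the same thing with `torch.profiler` instead of Nsight. `model` and `inputs` are assumed to be the objects from the repro script in the issue; the exported Chrome trace shows the `cudaLaunchKernel` calls on the CPU track next to the kernels they launch, similar to Nsight's CUDA API row.

```python
# Hypothetical profiling snippet: wrap one generate() call and export a trace.
# `model` and `inputs` come from the repro script in the issue body.
import torch
from torch.profiler import profile, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    outputs = model.generate(inputs)
torch.cuda.synchronize()

# One trace per rank; open in chrome://tracing or Perfetto to see the launch gaps.
rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
prof.export_chrome_trace(f"trace_rank{rank}.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```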
@jsheng-jian thanks. You are right, the bottleneck is the kernel launch gaps, because with 2 GPUs the kernel run time cannot cover the kernel launch time. I have rewritten the transformer block code in C++, and the issue is solved.
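For anyone else debugging this, a rough stand-alone sketch (a hypothetical micro-benchmark, not the actual model code) of how to check whether launch overhead exceeds kernel run time for small, launch-bound ops:

```python
# Hypothetical micro-benchmark: compare CPU launch time vs. GPU kernel time
# for a small op. If the launch cost exceeds the kernel time, the GPU starves
# between kernels, which shows up as gaps in the timeline.
import time
import torch

x = torch.randn(256, 256, device="cuda", dtype=torch.float16)  # small, launch-bound op

for _ in range(10):          # warm-up
    y = x @ x
torch.cuda.synchronize()

n = 1000
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

t0 = time.perf_counter()
start.record()
for _ in range(n):
    y = x @ x
end.record()
t_launch = time.perf_counter() - t0      # CPU time spent launching (no sync yet)
torch.cuda.synchronize()
t_gpu = start.elapsed_time(end) / 1000   # GPU time spent executing (seconds)

print(f"avg launch (CPU): {t_launch / n * 1e6:.1f} us, avg kernel (GPU): {t_gpu / n * 1e6:.1f} us")
```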
**Describe the bug**
DeepSpeed-inference 2-GPU performance is lower than 1-GPU on Bloomz 7.1B.
**To Reproduce**
Run the following code (`test.py`) on two V100 or A100 GPUs with this command: `deepspeed --num_gpus 2 --master_port 60000 test.py`

```python
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import deepspeed

os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
checkpoint = "bigscience/bloomz-7b1"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto", device_map="auto")

model = deepspeed.init_inference(
    model=model,
    mp_size=2,
    dtype=torch.float16,
    replace_method="auto",
    replace_with_kernel_inject=True,
)

inputs = tokenizer.encode("Translate to English: Je t’aime.", return_tensors="pt").to("cuda")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```
The performance with two GPUs is lower than with one GPU. The main reason is that the all_reduce time is too long, and the all_reduce is long because there are many gaps during kernel execution, as the picture shows: the all_reduce time on GPU0 is much longer than on GPU1, because there are long gaps during GPU1's op execution, so GPU0's all_reduce has to wait for GPU1 to reach the collective.
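To make this concrete, here is a small hypothetical script (not part of the repro, with an artificial delay standing in for the kernel gaps) showing that the all_reduce time recorded on the fast rank includes the time spent waiting for the slower rank to reach the collective:

```python
# Hypothetical demo, e.g. run with: torchrun --nproc_per_node=2 allreduce_wait_demo.py
# The recorded all_reduce time on rank 0 includes the time spent waiting for
# rank 1, which is delayed here to mimic the gaps between its kernels.
import time
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)
x = torch.ones(1 << 20, device="cuda")

if rank == 1:
    time.sleep(0.1)  # rank 1 arrives late at the collective

torch.cuda.synchronize()
t0 = time.perf_counter()
dist.all_reduce(x)
torch.cuda.synchronize()
print(f"rank {rank}: all_reduce took {(time.perf_counter() - t0) * 1e3:.1f} ms")

dist.destroy_process_group()
```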
**Expected behavior**
2-GPU inference should be faster than 1-GPU inference.
**Launcher context**
`deepspeed --num_gpus 2 --master_port 60000 test.py`

**Docker context**
`nvcr.io/nvidia/pytorch:22.09-py3`