Here's a simple piece of code which reproduces the issue:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import tensor_parallel as tp
import torch

checkpoint = "bigcode/starcoderbase"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True
)
model = tp.tensor_parallel(model, sharded=False)

def task():
    inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False))

thread1 = threading.Thread(target=task)
thread2 = threading.Thread(target=task)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
```
@zoubaihan Sadly, `tensor_parallel` models can't be used concurrently, because concurrent calls break the inter-GPU communications.
For example, the code @linfan provided shows how this happens: two concurrent calls reach a round of communication and both try to broadcast the first part of some tensor. They overwrite each other and leave the second part `None`, yet both consider the broadcast complete, since two parts were communicated in total.
For your specific case, I'd recommend finding a library that batches incoming requests, so the GPUs are utilized more efficiently (see the batching sketch below). The forward calls themselves must not be concurrent.
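A minimal sketch of the batching idea, not tied to any particular serving library; the `pad_token` / `padding_side` settings and the second prompt are assumptions that depend on your tokenizer and workload:

```python
# A rough sketch of batching two requests into a single forward pass instead of
# running them concurrently. Setting pad_token / padding_side this way is an
# assumption -- check what your tokenizer actually requires.
tokenizer.pad_token = tokenizer.eos_token  # GPT-style tokenizers often lack a pad token
tokenizer.padding_side = "left"            # left-pad so generation continues from real tokens

prompts = ["def print_hello_world():", "def fibonacci(n):"]
batch = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")
outputs = model.generate(**batch)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))
```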
tldr: no concurrency, use locks
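If requests really must arrive from multiple threads, a single lock around `generate()` serializes the forward passes. A minimal sketch, reusing the `model` and `tokenizer` from the reproduction above:

```python
import threading

# One global lock so that only one thread drives the tensor_parallel model at a time.
generate_lock = threading.Lock()

def task():
    inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
    with generate_lock:  # serialize forward passes so the communications stay consistent
        outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False))
```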
When I use `tensor_parallel` to run my model inference on two GPUs (the model is deployed with Flask pywsgi), I also use gunicorn to manage my Flask app so it can accept many requests and run model inference concurrently. This works fine on a single GPU, but as soon as I use the `tensor_parallel` lib, I get this error:

```
tensor_parallel/cross_device_ops.py", line 78, in forward
    inputs = tuple(map(torch.Tensor.contiguous, inputs))
TypeError: descriptor 'contiguous' for 'torch._C._TensorBase' objects doesn't apply to a 'NoneType' object
```

It also raises this exception:

```
TypeError: Caught TypeError in replica 0 on device 0.
```

What shall I do? Can you analyze the problem?