BlackSamorez / tensor_parallel

Automatically split your PyTorch models on multiple GPUs for training & inference
MIT License

Does tensor_parallel support the model inference concurrently or in multi-threads? #86

Closed · zoubaihan closed this issue 1 year ago

zoubaihan commented 1 year ago

When I use tensor_parallel to run my model's inference on two GPUs (the model is deployed behind Flask's pywsgi), I also use gunicorn to manage my Flask app so that it can accept many requests and run inference concurrently. This works fine on a single GPU, but as soon as I use the tensor_parallel library I get this error:

tensor_parallel/cross_device_ops.py", line 78, in forward
    inputs = tuple(map(torch.Tensor.contiguous, inputs))
TypeError: descriptor 'contiguous' for 'torch._C._TensorBase' objects doesn't apply to a 'NoneType' object

It also raises this exception:

TypeError: Caught TypeError in replica 0 on device 0.

What should I do? Can you analyze the problem?

linfan commented 1 year ago

Here's a simple piece of code that reproduces the issue:

from transformers import AutoModelForCausalLM, AutoTokenizer
import threading
import tensor_parallel as tp
import torch

checkpoint = "bigcode/starcoderbase"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True
)
model = tp.tensor_parallel(model, sharded=False)  # split the model across all available GPUs

def task():
    inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
    outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False))

# two threads call generate() on the same tensor_parallel model concurrently
thread1 = threading.Thread(target=task)
thread2 = threading.Thread(target=task)

thread1.start()
thread2.start()

thread1.join()
thread2.join()
BlackSamorez commented 1 year ago

@zoubaihan sadly, tensor_parallel models can't be used concurrently, because the communications break otherwise. For example, the code @linfan provided shows what happens when two concurrent calls reach a round of communication and both try to broadcast the first part of some tensor: they overwrite each other and leave the second part None, yet both consider the broadcast complete since two parts were communicated in total. For your specific case I'd recommend finding a library that batches requests, which also uses the GPUs more efficiently. But the forward calls themselves must not be concurrent.

tldr: no concurrency, use locks
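
A minimal sketch of the lock-based workaround, assuming the same model and tokenizer setup as in the reproduction above (generate_lock and task are illustrative names, not part of the tensor_parallel API):

import threading

# One global lock so that only a single thread runs a generate()/forward pass
# through the shared tensor_parallel model at a time.
generate_lock = threading.Lock()

def task():
    inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
    with generate_lock:  # serialize access to the shared tensor_parallel model
        outputs = model.generate(inputs)
    print(tokenizer.decode(outputs[0], clean_up_tokenization_spaces=False))

With this, the two threads from the example above no longer overlap inside the model's cross-device communication, at the cost of running their generations one after another.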