microsoft / DeepSpeed-MII

MII makes low-latency and high-throughput inference possible, powered by DeepSpeed.
Apache License 2.0

Block when Call client inference in multiprocessing.Process #449

Open zhaotyer opened 6 months ago

zhaotyer commented 6 months ago

I tried to integrate MII into tritonserver, but encountered some problems. Below is part of my code:

# Imports used by this snippet (env_manager, logger, and the send helpers
# are defined elsewhere in the original code).
import asyncio
import json
from multiprocessing import Process

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        import mii
        from transformers import AutoTokenizer
        tensor_parallel_size = 1
        cuda_env = env_manager.cuda_visible_devices
        if cuda_env is None:
            from torch.cuda import device_count
            tensor_parallel_size = device_count()
        else:
            tensor_parallel_size = len(cuda_env.split(",")) if cuda_env else 1
        self._model = mii.serve(self.base_model_path, deployment_name="atom", tensor_parallel=tensor_parallel_size)
        self._tokenizer = AutoTokenizer.from_pretrained(self.base_model_path, trust_remote_code=True)

    def execute(self, requests):
        # Decoupled mode: responses go back through each request's response
        # sender, so execute() itself returns None.
        responses = []
        for request in requests:
            self.process_request(request)
        return None

    def process_request(self, request):
        # self.create_task(self.mii_response_thread(request.get_response_sender(), request))
        thread = Process(target=self.mii_response_thread,
                         args=(request.get_response_sender(), request))
        # thread.daemon = True
        thread.start()

    def mii_response_thread(self, response_sender, request):
        try:
            import mii
            event_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(event_loop)
            req_desc = pb_utils.get_input_tensor_by_name(request, "JSON")
            req_json = json.loads(req_desc.as_numpy()[0])
            stop = req_json.get('stop', False)
            query, prompt, history, stream, gen_config, response_config, tools, mode  = self.process_input_params(request, req_json)
            client = mii.client('atom')
            output_tokens = []
            def callback(response):
                logger.debug(f"Received: {response[0].generated_text}")
                self.send(response_sender, response[0].generated_text)
                # print(f"Received: {response[0].generated_text} time_last_token={time_last_token}")
                output_tokens.append(response[0].generated_text)
            logger.debug("call mii generate")
            client.generate(prompt, max_new_tokens=4096, streaming_fn=callback)
            logger.info(f"output text is:{''.join(output_tokens)}")
        except Exception as e:
            logger.exception(f"Capture error:{e}")
            self.send_error(response_sender, f"Error occur:{e}")
        finally:
            self.send_final(response_sender)
            # self.handler.ongoing_request_count -= 1

The problem is: when I use

thread = Process(target=self.mii_response_thread,
                 args=(request.get_response_sender(), request))

MII blocks at

async for response in getattr(self.stub,
                              task_methods.method_stream_out)(proto_request):
    yield task_methods.unpack_response_from_proto(response)

When I use

thread = Thread(target=self.mii_response_thread,
                args=(request.get_response_sender(), request))

inference works normally, but gRPC keeps reporting errors (they do not affect inference, but the service is not stable): https://github.com/grpc/grpc/issues/25364
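In case it is useful while this is being investigated: gRPC channels and asyncio event loops generally do not survive a fork(), so one workaround worth trying (an untested sketch, not a confirmed fix; it assumes a server was already started with mii.serve under the deployment name "atom") is to use the "spawn" start method and create the mii.client only inside the child process:

import multiprocessing as mp

def generate_in_child(deployment_name, prompt):
    # Create the client inside the child process; gRPC channels and asyncio
    # event loops created in the parent generally cannot be reused after fork.
    import mii
    client = mii.client(deployment_name)
    response = client.generate(prompt, max_new_tokens=256)
    print(response)

if __name__ == "__main__":
    # "spawn" gives the child a clean interpreter with no inherited gRPC or
    # event-loop state (assumption: this is what avoids the hang).
    ctx = mp.get_context("spawn")
    p = ctx.Process(target=generate_in_child, args=("atom", "hello"))
    p.start()
    p.join()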

nxznm commented 5 months ago

I hit a similar case. Here is my code:

import threading

import mii

def worker(rank, this_model):
    try:
        if this_model is None:
            client = mii.client('qwen')
        else:
            client = this_model
        response = client.generate(["xxx"], max_new_tokens=1024, stop="<|im_end|>", do_sample=False, return_full_text=True)
        print("in worker rank:", rank, " response:", response)
    except Exception as e:
        print(f"Capture error:{e}")
    finally:
        print("final")

model = mii.serve(model_dir, deployment_name="qwen", tensor_parallel=xx, replica_num=replica_num)

job_process = []
for rank in range(replica_num):
    if rank == 0:
        job_process.append(threading.Thread(target=worker, args=(rank, model)))
    else:
        job_process.append(threading.Thread(target=worker, args=(rank, None)))
for process in job_process:
    process.start()
for process in job_process:
    process.join()

When using threading.Thread, it works well. However, it blocks in client.generate when using multiprocessing.Process.

nxznm commented 5 months ago


Since threading.Thread is limited by Python's GIL, the code above cannot make full use of concurrency. That means I still need multiprocessing.Process to start a new client, but, as mentioned above, that does not work.
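One note on the GIL point: client.generate is a gRPC call into the separate mii.serve process, so the client thread mostly waits on network I/O (during which the GIL is released) while the model runs server-side. A thread pool may therefore still give real request-level concurrency. A minimal sketch (the deployment name "qwen" and the prompts are placeholders, and each call constructs its own client):

from concurrent.futures import ThreadPoolExecutor

import mii

def run_one(prompt):
    # The call blocks on gRPC I/O to the mii.serve process; the GIL is
    # released while waiting, so threads still overlap requests.
    client = mii.client("qwen")
    return client.generate([prompt], max_new_tokens=1024)

prompts = ["xxx"] * 8  # placeholder prompts
with ThreadPoolExecutor(max_workers=8) as pool:
    for response in pool.map(run_one, prompts):
        print(response)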

nxznm commented 5 months ago


I found the official example. Maybe we should start the server and the clients in those ways.
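For context, the persistent-deployment pattern in the MII README looks roughly like the following (paraphrased from memory, so the repository examples are authoritative): one script starts the server, separate client scripts connect to it, and a shutdown script terminates it.

# serve.py - start the persistent deployment once
import mii
mii.serve("mistralai/Mistral-7B-v0.1", deployment_name="mistral-deployment")

# client.py - run from any other process or script
import mii
client = mii.client("mistral-deployment")
response = client.generate(["Deepspeed is", "Seattle is"], max_new_tokens=128)
print(response)

# terminate.py - shut the server down when finished
import mii
client = mii.client("mistral-deployment")
client.terminate_server()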