I used your code as motivation for my implementation, which is rather similar. In my case I needed to make Llama 2 work with SQS polling, and the problem was that every worker process needs to execute the same code when a message arrives for a result to be generated. Thanks for sharing your implementation.
Here is mine in case it helps anyone else:
```python
import threading

import torch.distributed as dist

queue_listener_thread = None


def start_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, process_message_callback):
    global queue_listener_thread
    queue_listener_thread = threading.Thread(
        # consume_with_retry is a local library that wraps the SQS interaction; essentially
        # it just polls for messages and calls the first argument when there's something to process.
        target=consume_with_retry,
        args=(
            process_message_callback,
            [sqs_input_queue_name],
        ),
    )
    queue_listener_thread.daemon = True
    queue_listener_thread.start()


def process_sqs_message_with_multiple_gpus(job_dict, process_fn, after_processing_callback):
    log_info(f"GPU 0 starting broadcast with {job_dict}.")
    dist.broadcast_object_list([job_dict])
    log_info("GPU 0 broadcast done.")
    process_fn(job_dict, after_processing_callback, 0)


def start_multi_gpu_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, process_fn):
    if dist.get_rank() == 0:
        # Rank 0 is the only process that talks to SQS; it rebroadcasts each job to the others.
        log_info("GPU 0 starting to listen to SQS messages.")
        start_sqs_listener(
            sqs_input_queue_name,
            sqs_output_queue_name,
            lambda job_dict, after_processing_callback: process_sqs_message_with_multiple_gpus(
                job_dict, process_fn, after_processing_callback
            ),
        )
    else:
        # All other ranks block on the broadcast and run the same job when it arrives.
        while True:
            log_info(f"GPU {dist.get_rank()} starting to listen for broadcast.")
            config = [None] * 1
            dist.broadcast_object_list(config)
            log_info(f"GPU {dist.get_rank()} received the broadcast.")
            process_fn(config[0], lambda a: None, dist.get_rank())


# _do_work is the function that calls chat_completion.
if gpu_count > 1:  # Multi-GPU
    start_multi_gpu_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, _do_work)
else:  # Single GPU
    start_sqs_listener(
        sqs_input_queue_name,
        sqs_output_queue_name,
        lambda job_dict, after_processing_callback: _do_work(job_dict, after_processing_callback, 0),
    )
```
I didn't look at it that carefully, but is there a reason you're calling generator.generate instead of generator.chat_completion for the worker processes? In my case I call chat_completion in both cases and it works.
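For reference, a _do_work that calls chat_completion on every rank could look roughly like this; the generator object, job format, and result handling are assumptions about code not shown here:

```python
# Rough sketch only: `generator` is assumed to be a llama.Llama instance built on every rank,
# and the job/result fields are assumptions about the surrounding code.
def _do_work(job_dict, after_processing_callback, rank):
    dialogs = [[{"role": "user", "content": job_dict["prompt"]}]]
    # Every rank runs the same call, so the collectives inside chat_completion line up.
    results = generator.chat_completion(dialogs, max_gen_len=None, temperature=0.6, top_p=0.9)
    if rank == 0:
        # Only rank 0 reports the answer (e.g. to the SQS output queue).
        after_processing_callback(results[0]["generation"]["content"])
```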
I want to make a web service from the 70B-chat model, but there are some bugs or errors. Here is the launch shell:
Here is the code:
The first inference after the model is built succeeds, but when I run inference a second time with
curl -X POST -H "Content-Type: application/json" -d '{"prompts": "give me the result of 55*32 "}' 127.0.0.1:8000
the request is accepted and enters the generation step, but the volatile utilization of all 8 GPUs immediately jumps to 100% and nothing is returned even after a long wait. After 1800 s the processes are shut down:
I guess it is a deadlock in the parallel computation, but I can't fix it. Is there any reliable web-service code for the 70B-chat model?
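Based on the comment above, I think the missing piece is something like the following sketch, where only rank 0 runs the web server and broadcasts each request so that every rank enters generation together. FastAPI, the prompt field, and the chat_completion call here are assumptions, not verified code:

```python
# Sketch only: `generator` is assumed to be a llama.Llama instance built on every rank,
# and the request format matches the curl example above ({"prompts": "..."}).
import torch.distributed as dist


def serve(generator):
    if dist.get_rank() == 0:
        import uvicorn
        from fastapi import FastAPI

        app = FastAPI()

        @app.post("/")
        def infer(payload: dict):
            # Broadcast the request before inference so all ranks join the collective ops.
            dist.broadcast_object_list([payload])
            results = generator.chat_completion([[{"role": "user", "content": payload["prompts"]}]])
            return {"result": results[0]["generation"]["content"]}

        uvicorn.run(app, host="0.0.0.0", port=8000)
    else:
        while True:
            payload = [None]
            dist.broadcast_object_list(payload)  # blocks until rank 0 forwards a request
            generator.chat_completion([[{"role": "user", "content": payload[0]["prompts"]}]])
```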