meta-llama / llama

Inference code for Llama models

llama-2-70B-chat cannot inference again, multi-gpu volatile all 100% #468

Open PtttCode opened 1 year ago

PtttCode commented 1 year ago

I want to build a web service around the 70B-chat model, but I'm running into some bugs or errors. Here is the launch command:

torchrun --nproc_per_node 8 example_chat_completion.py \
    --ckpt_dir llama-2-70b-chat/ \
    --tokenizer_path tokenizer.model \
    --max_seq_len 512 --max_batch_size 4

Here is the code:


from typing import Optional, List
import torch.distributed as dist
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel

app = FastAPI()

system_prompt = {
                "role": "system",
                "content": "Always answer with Chinese"
            }

dialogs = [
        [
            system_prompt,
            {
                "role": "user",
                "content": "write a 500 words story"
            }
        ]
    ]

def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    return generator

def completion(): 
    results = generator.chat_completion(
        dialogs,  # type: ignore
        max_gen_len=512,
        temperature=0.6,
        top_p=0.9,
    )

    for dialog, result in zip(dialogs, results):
        for msg in dialog:
            print(f"{msg['role'].capitalize()}: {msg['content']}\n")
        print(
            f"> {result['generation']['role'].capitalize()}: {result['generation']['content']}"
        )
        print("\n==================================\n")

class PromptItem(BaseModel):
    role: str
    content: str

class Config(BaseModel):
    prompts: str = ''
    max_gen_len: int=512
    temperature: float = 0.6
    top_p: float = 0.9

if __name__ == "__main__":
    import fire
    import json

    from llama import Llama
    generator = fire.Fire(main)

    # First inference
    completion()

    if dist.get_rank() == 0:
        @app.post("/")
        def generate(config: Config):
            dist.broadcast_object_list([config.prompts, config.max_gen_len, config.temperature, config.top_p])
            # json_data = config.model_dump()
            req_data = [[system_prompt, {'role': 'user', 'content': config.prompts}]]
            results = generator.chat_completion(
                req_data, max_gen_len=config.max_gen_len, temperature=config.temperature, top_p=config.top_p
            )

            return {"responses": results}

        uvicorn.run(app, host="127.0.0.1")
    else:
        while True:
            config = [None] * 4
            try:
                dist.broadcast_object_list(config)
                generator.generate(
                    config[0], max_gen_len=config[1], temperature=config[2], top_p=config[3]
                )
            except:
                pass

The first inference after the model is built succeeds, but when I try a second inference with curl -X POST -H "Content-Type: application/json" -d '{"prompts": "give me the result of 55*32 "}' 127.0.0.1:8000 the request is accepted and reaches the generate step. However, GPU utilization on all 8 GPUs immediately jumps to 100% and nothing is ever returned, no matter how long I wait.

(screenshot: all 8 GPUs at 100% utilization)

After about 1800 seconds, the processes are shut down (screenshot of the terminated processes).

I suspect this is a deadlock in the multi-GPU parallel computation, but I can't fix it. Alternatively, is there any reliable web-service code for the 70B-chat model?

RoopeHakulinen commented 1 year ago

I used your code as inspiration for my implementation, which is rather similar. In my case I needed to make Llama 2 work with SQS polling, and the problem was the same: every worker process needs to execute the same code when a message arrives for a result to be generated. Thanks for sharing it.

Here's the implementation in case it helps anyone else:

import threading

import torch.distributed as dist

queue_listener_thread = None

def start_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, process_message_callback):
    global queue_listener_thread

    queue_listener_thread = threading.Thread(
        # consume_with_retry is a local library that wraps the SQS interaction; essentially it
        # just polls for messages and calls the first argument when there's something to process.
        target=consume_with_retry,
        args=(
            process_message_callback,
            [sqs_input_queue_name],
        ),
    )
    queue_listener_thread.daemon = True
    queue_listener_thread.start()

def process_sqs_message_with_multiple_gpus(job_dict, process_fn, after_processing_callback):
    log_info(f"GPU 0 starting broadcast with {job_dict}.")
    dist.broadcast_object_list([job_dict])
    log_info("GPU 0 broadcast done.")
    process_fn(job_dict, after_processing_callback, 0)

def start_multi_gpu_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, process_fn):
    if dist.get_rank() == 0:
        log_info("GPU 0 starting to listen to SQS messages.")
        start_sqs_listener(
            sqs_input_queue_name,
            sqs_output_queue_name,
            lambda job_dict, after_processing_callback: process_sqs_message_with_multiple_gpus(
                job_dict, process_fn, after_processing_callback
            ),
        )
    else:
        while True:
            log_info(f"GPU {dist.get_rank()} starting to listen for broadcast.")
            config = [None] * 1
            dist.broadcast_object_list(config)
            log_info(f"GPU {dist.get_rank()} received the broadcast.")
            process_fn(config[0], lambda a: None, dist.get_rank())

# _do_work is the function that calls chat_completion
if gpu_count > 1:  # Multi-GPU
    start_multi_gpu_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, _do_work)
else:  # Single GPU
    start_sqs_listener(sqs_input_queue_name, sqs_output_queue_name, lambda job_dict, after_processing_callback: _do_work(job_dict, after_processing_callback, 0))
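
For completeness, _do_work is roughly the following shape; this is a simplified sketch rather than my exact code (the real version also pushes the result to the SQS output queue), and generator and the job_dict field names here are just placeholders from the surrounding module:

def _do_work(job_dict, after_processing_callback, rank):
    # Every rank runs the exact same chat_completion call so the
    # model-parallel collectives stay in sync across GPUs.
    dialogs = [[{"role": "user", "content": job_dict["prompt"]}]]
    results = generator.chat_completion(
        dialogs,
        max_gen_len=job_dict.get("max_gen_len", 512),
        temperature=job_dict.get("temperature", 0.6),
        top_p=job_dict.get("top_p", 0.9),
    )
    # Only rank 0 reports the result back; the other ranks only
    # participate in the computation.
    if rank == 0:
        after_processing_callback(results)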

I didn't look at it that carefully, but is there a reason you're calling generator.generate instead of generator.chat_completion for the worker processes? In my case I call chat_completion on both sides and it works.
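
Concretely, something like this for your worker branch should keep the ranks in sync; it's a sketch based on your code (not tested with your exact setup) that mirrors what rank 0 does in the POST handler:

else:
    while True:
        payload = [None] * 4
        dist.broadcast_object_list(payload)
        prompts, max_gen_len, temperature, top_p = payload
        # Rebuild the same dialog structure rank 0 passes to chat_completion
        # so that every rank executes identical collective operations.
        dialogs = [[system_prompt, {"role": "user", "content": prompts}]]
        generator.chat_completion(
            dialogs, max_gen_len=max_gen_len, temperature=temperature, top_p=top_p
        )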