vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Precise model device placement #6189

Open vwxyzjn opened 2 months ago

vwxyzjn commented 2 months ago

🚀 The feature, motivation and pitch

Hi all, I was wondering if it's possible to do precise model device placement. For example, I would like to place the vLLM model on GPU 1 and let GPU 0 do other things. Being able to do precise device placement would help unblock online RLHF work in Hugging Face's TRL, because we want to leverage vLLM's fast generation.

In particular, we'd like to run training on 7 GPUs and leave only 1 GPU for vLLM inference. I have a very crude hack that supports this at https://github.com/vwxyzjn/vllm/pull/1, but I figure more general support in vLLM would be more helpful.

Currently this is not possible because the following code errors out:

from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="gpt2", tensor_parallel_size=1, device="cuda:1")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Alternatives

No response

Additional context

No response

DarkLight1337 commented 2 months ago

Possibly a stupid question, but have you considered setting CUDA_VISIBLE_DEVICES via shell when running the vLLM script?
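
For example, a rough sketch of what I mean (the script name is hypothetical; the key point is that the variable is set before CUDA is initialized):

# Shell form: CUDA_VISIBLE_DEVICES=1 python vllm_inference.py
# or, equivalently, from inside the inference script itself:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # must be set before torch/vLLM touch CUDA

from vllm import LLM, SamplingParams

# Physical GPU 1 now shows up as cuda:0 inside this process, which is the
# device vLLM uses by default.
llm = LLM(model="gpt2", tensor_parallel_size=1)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, top_p=0.95))
print(outputs[0].outputs[0].text)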

vwxyzjn commented 2 months ago

That was the first thing I tried. What happens is that the script then cannot see the training GPUs.

DarkLight1337 commented 2 months ago

Why does the vLLM inference script need to see the training GPUs?

vwxyzjn commented 2 months ago

I would like to run training and inference in the same script, so I can more easily load the online-trained weights into vLLM (unless there is a more elegant way of doing it).

import time

import torch
from accelerate import Accelerator
from accelerate.state import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm import SamplingParams, SingleGPULLM

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
tok = AutoTokenizer.from_pretrained("vwxyzjn/ppo_zephyr7")
prompt_ids = tok.batch_encode_plus(prompts)["input_ids"]
accelerator = Accelerator(gradient_accumulation_steps=2)
state = PartialState()
llm = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/mistral-7b-sft-beta")
llm = llm.to(accelerator.device)
accelerator.print(f"{torch.cuda.device_count()=}")
if state.is_main_process:
    sampling_params = SamplingParams(temperature=0.001, top_p=1.0)
    inference_llm = SingleGPULLM(model="vwxyzjn/ppo_zephyr7",
                                 tensor_parallel_size=1,
                                 device="cuda:7")
    llmp = inference_llm.llm_engine.model_executor.driver_worker.model_runner.model
    print(f"πŸ”₯πŸ”₯πŸ”₯ vllm lives in {llmp.lm_head.weight.device}")
    print("prepare to generate")
    outputs = inference_llm.generate(prompt_token_ids=prompt_ids,
                                     sampling_params=sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

    print("πŸ”₯πŸ”₯πŸ”₯ Loading weights using shared memory;"
          "we expect the generations to be completely different")
    start_time = time.time()
    llmp.load_weights(llm.named_parameters())
    print(f"Time to load weights: {time.time() - start_time:.2f} seconds")
    outputs = inference_llm.generate(prompt_token_ids=prompt_ids,
                                     sampling_params=sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
else:
    # llm.forward
    # llm.backward()
    print("I'm waiting for the main process to generate...")
accelerator.wait_for_everyone()

DarkLight1337 commented 2 months ago

Hmm... @youkaichao any thoughts on this?

youkaichao commented 2 months ago

It is not possible until we separate the driver process and the TP rank 0 process. Currently, they live in the same process as the user's process.

DZ9 commented 2 months ago

How about launching a separate vLLM server for this? It would be much easier and more flexible, I think.
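
Something along these lines (a sketch only; the launch command is vLLM's OpenAI-compatible server, and the port and model here are placeholders):

# In a separate shell, pin the server to one GPU, e.g.:
#   CUDA_VISIBLE_DEVICES=7 python -m vllm.entrypoints.openai.api_server \
#       --model HuggingFaceH4/mistral-7b-sft-beta --port 8000
# Then the training script only talks to it over HTTP:
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "HuggingFaceH4/mistral-7b-sft-beta",
        "prompt": "The capital of France is",
        "max_tokens": 32,
        "temperature": 0.8,
    },
)
print(resp.json()["choices"][0]["text"])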

vwxyzjn commented 2 months ago

Conceptually that works! My question in that case is how you would load the model weights efficiently. With my current pipeline, loading a 7B model takes 0.01 seconds (though maybe that's just because PyTorch is doing an async copy).
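
For what it's worth, the way I'd check whether that 0.01 seconds is a real copy or just an async launch is to synchronize around the timing (a minimal sketch, reusing llmp and llm from the script above):

import time
import torch

# Force all pending CUDA work to finish before and after, so the measurement
# reflects the actual weight copy rather than just queuing the kernels.
torch.cuda.synchronize()
start_time = time.time()
llmp.load_weights(llm.named_parameters())
torch.cuda.synchronize()
print(f"Time to load weights: {time.time() - start_time:.2f} seconds")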

DZ9 commented 1 month ago

I really appreciate your great work! I've been investigating vLLM generation for a while, and I'm aware of your concern about efficiency. Right now the only open-source solution I've found for this is implemented in OpenRLHF, which broadcasts params to the vllm_engine through a Ray cluster [here]. But running a Ray cluster is heavy, so I'm very pleased to see that you are trying to solve this in a simpler way. I hope you can find a workaround. Thanks again for the epic work you've done!
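
For anyone landing here later, the rough shape of that approach is to put the training rank and the process hosting the vLLM engine in one process group and broadcast the updated parameters to it. A simplified sketch (the function name and signature are mine, not OpenRLHF's, and it assumes the vLLM rank holds same-shaped buffers for every parameter):

import torch.distributed as dist

def broadcast_weights_to_vllm(named_parameters, vllm_model=None, src_rank=0, group=None):
    # Every rank in `group` calls this; the training rank (src_rank) sends,
    # the vLLM rank receives and then loads the weights into the engine.
    received = []
    for name, param in named_parameters:
        tensor = param.data
        dist.broadcast(tensor, src=src_rank, group=group)
        received.append((name, tensor))
    if vllm_model is not None:
        # Same call the in-process hack uses: vLLM's model.load_weights()
        vllm_model.load_weights(received)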

joanvelja commented 5 days ago

@vwxyzjn Did you find a solution to this? I sent you a Twitter DM too, since I would love to do the same for a multi-agent RL pipeline I have going on.

vwxyzjn commented 3 days ago

Yes. You can apply a monkey patch like the one in https://github.com/allenai/open-instruct/blob/online-trainers/open_instruct/vllm_utils.py. Then you can do things like https://github.com/allenai/open-instruct/blob/5641385b1bec87d80b61bb219325be7fecac71c3/open_instruct/online_dpo_vllm.py#L360-L368
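
For reference, the rough shape of the usage after that patch (a sketch only; trained_model and prompt_ids are stand-ins for whatever your training loop produces, and the attribute chain is the same one used earlier in this thread):

from vllm import LLM, SamplingParams

# With the patch applied, the LLM can be pinned to a chosen GPU while the
# training GPUs remain visible to the rest of the script.
llm = LLM(model="gpt2", tensor_parallel_size=1)
vllm_model = llm.llm_engine.model_executor.driver_worker.model_runner.model

# trained_model / prompt_ids come from your training loop (placeholders here).
vllm_model.load_weights(trained_model.named_parameters())
outputs = llm.generate(prompt_token_ids=prompt_ids,
                       sampling_params=SamplingParams(temperature=0.7, top_p=0.95))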