vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Allow LoRA adapters to be specified as in-memory dict of tensors #4068

Closed jacobthebanana closed 10 hours ago

jacobthebanana commented 7 months ago

🚀 The feature, motivation and pitch

PPO and a number of other LLM fine-tuning techniques require autoregressive generation as part of the training process. When using vLLM to speed up the autoregressive generation part of the training loop, is there an efficient way to update the weights of the LLM? Specifically, in the case of LoRA fine-tuning, is there a way to efficiently swap out the adapters without having to save them to the filesystem?

Alternatives

Efficient LoRA adapter update

Possible workaround without any code change: save adapters to an in-memory filesystem (e.g., /dev/shm) and point to that directory in each LoRARequest. This workaround:
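Here is a minimal sketch of the /dev/shm workaround. It assumes the v0.4.0-era `LoRARequest(lora_name, lora_int_id, lora_local_path)` on the vLLM side. The helper names (`PublishedAdapter`, `publish_adapter`, `retire_adapter`) and the directory layout are hypothetical, and `save_fn` stands in for something like PEFT's `save_pretrained`. Each adapter version gets its own directory and a new integer id. This rests on the assumption that vLLM keys loaded adapters by `lora_int_id` and could otherwise keep serving cached weights.

```python
import os
import shutil
import tempfile
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class PublishedAdapter:
    """Arguments for a v0.4.0-style LoRARequest(lora_name, lora_int_id, lora_local_path)."""
    lora_name: str
    lora_int_id: int
    lora_local_path: str


def publish_adapter(save_fn: Callable[[str], None], step: int,
                    root: str = "/dev/shm/lora_adapters") -> PublishedAdapter:
    """Save one adapter version into its own directory under a tmpfs root (hypothetical helper)."""
    path = os.path.join(root, f"step_{step}")
    os.makedirs(path, exist_ok=True)
    # In a real training loop save_fn would be e.g. peft_model.save_pretrained.
    save_fn(path)
    # New id per version, kept >= 1: an id vLLM has already loaded may be
    # served from its adapter cache instead of being re-read from disk.
    return PublishedAdapter(f"policy_step_{step}", step + 1, path)


def retire_adapter(adapter: PublishedAdapter) -> None:
    """Delete an old version once no in-flight request still uses it."""
    shutil.rmtree(adapter.lora_local_path, ignore_errors=True)


if __name__ == "__main__":
    # Demo with a throwaway directory and a stand-in for save_pretrained.
    demo_root = tempfile.mkdtemp()

    def fake_save(path: str) -> None:
        with open(os.path.join(path, "adapter_config.json"), "w") as f:
            f.write("{}")

    first = publish_adapter(fake_save, step=0, root=demo_root)
    second = publish_adapter(fake_save, step=1, root=demo_root)
    retire_adapter(first)
    print(second)
```

The returned fields would then be wrapped in a `LoRARequest` and passed as the `lora_request` argument of `LLM.generate`. Writing to tmpfs avoids disk I/O. However, the adapter is still serialized and deserialized on every update, which is the step this issue asks to avoid.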

Proposed change: modify LoRARequest to allow adapters to be specified as a dictionary of tensors.
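To make the proposed change concrete, here is a hypothetical request type that accepts either a path or an in-memory dictionary of tensors. `InMemoryLoRARequest`, `lora_tensors`, and `lora_config` are illustrative names and are not part of the vLLM API. The sketch also carries the adapter config. Today vLLM reads the rank, alpha, and target modules from `adapter_config.json` next to the weights, and the assumption here is that this information would still be needed when there is no directory.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Optional


@dataclass
class InMemoryLoRARequest:
    """Hypothetical variant of LoRARequest; not part of the vLLM API."""
    lora_name: str
    lora_int_id: int
    lora_local_path: Optional[str] = None
    # Parameter name -> tensor (e.g. torch.Tensor from a PEFT adapter
    # state dict). Typed as Any to keep the sketch free of a torch import.
    lora_tensors: Optional[Mapping[str, Any]] = None
    # Adapter config fields that would otherwise come from adapter_config.json.
    lora_config: Optional[Mapping[str, Any]] = None

    def __post_init__(self) -> None:
        if self.lora_int_id < 1:
            raise ValueError("lora_int_id must be a positive integer")
        has_path = self.lora_local_path is not None
        has_tensors = self.lora_tensors is not None
        # Exactly one source of weights: a directory or an in-memory dict.
        if has_path == has_tensors:
            raise ValueError("specify exactly one of lora_local_path or lora_tensors")
        if has_tensors and self.lora_config is None:
            raise ValueError("in-memory adapters also need their adapter config")
```

Keeping `lora_local_path` as an alternative would leave the existing file-based workflow unchanged. Trainers on the same machine could then hand over the adapter state dict directly.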

Alternative approach: non-LoRA parameter update

Additional context

LLM fine-tuning objectives such as PPO require autoregressive text generation during training, with the requirement that a reasonably up-to-date copy of the model is used for generation.
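The "reasonably up-to-date" requirement can be sketched as a loop that pushes fresh weights to the generator every `sync_every` optimizer steps. `generate`, `train_step`, and `push_weights` are hypothetical callables. They stand in for vLLM sampling, a PPO update, and whatever transfer mechanism is used (a new LoRA adapter or a full weight update). The function reports how many optimizer steps behind the generator was at each rollout.

```python
from typing import Callable, List


def rollout_train_loop(num_steps: int, sync_every: int,
                       generate: Callable[[int], list],
                       train_step: Callable[[list], None],
                       push_weights: Callable[[int], None]) -> List[int]:
    """Run PPO-style iterations; return generator staleness per step (hypothetical sketch)."""
    if sync_every < 1:
        raise ValueError("sync_every must be >= 1")
    push_weights(0)  # the generator starts from the initial policy
    last_synced = 0
    staleness = []
    for step in range(num_steps):
        if step - last_synced >= sync_every:
            push_weights(step)  # e.g. publish a new LoRA adapter version
            last_synced = step
        staleness.append(step - last_synced)
        rollouts = generate(step)  # sample with the generator's current weights
        train_step(rollouts)       # update the trainer's copy of the policy
    return staleness
```

A smaller `sync_every` keeps rollouts closer to on-policy but requires more frequent weight transfers. This is why the cost of each transfer matters for this feature request.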

As of the v0.4.0 vLLM release, when instantiating a vLLM LoRARequest, the LoRA adapters are specified through the lora_local_path: str attribute. (source) In the LoRA PPO example above, if the vLLM instance is on the same machine as the peft training loop, sending a new copy of the adapter weights to vLLM would require the following steps:

If the proposed alternative is adopted, the new workflow would be:

Related Issues

The idea of adding new LoRA adapters without restarting vLLM is related to #3308 with some differences:

If the changes proposed in this feature request are merged, these features could eventually be added to the vLLM OpenAI-compatible HTTP API, for example to allow trusted remote users to add LoRA adapters to a vLLM server without first writing the adapters to a filesystem on the server.

vwxyzjn commented 7 months ago

@jacobthebanana that's so cool to know that Open RLHF does something like that. Do you know if there's a minimal example with the weight broadcasting?

jacobthebanana commented 6 months ago


A colleague of mine shared this example in the OpenRLHF repository: examples/train_ppo_ray.py. For the use case I'm most interested in, the main drawback of this approach is that it requires a large number of GPUs. Specifically, in this setup the vLLM engine needs to run on its own set of GPUs, separate from the ones that run backpropagation.

For reference, the OpenRLHF full-rank vLLM hot-swapping logic can be found in openrlhf/trainer/ray/vllm_worker_wrap.py, which is invoked in openrlhf/trainer/ray/ppo_actor.py.

Another challenge is that it is not straightforward to run Torch FSDP alongside vLLM (which uses Ray) on the same set of GPUs. That might become easier once vLLM pull request #3466 is merged. My team at work has built a proof-of-concept for LoRA weight "broadcasting" based on the changes proposed in that pull request, using the /dev/shm workaround mentioned above. I would be happy to share more about that effort if you are interested.

(Also, apparently my GitHub email notifications weren't set up correctly. Sorry for the delay in replying.)

vwxyzjn commented 6 months ago

That's so cool and thanks for replying! I feel a really impactful project is to train vLLM models directly somehow.

In terms of online RLHF, I also made it possible to place the model on a specific device https://github.com/vwxyzjn/vllm/pull/1 and then apply it to TRL: https://github.com/huggingface/trl/pull/1540. The idea is to load the vLLM model in the 8th GPU and use the remaining GPUs to do training.
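One generic way to express "vLLM on the last GPU, training on the rest" is to give each process its own `CUDA_VISIBLE_DEVICES`. This sketch relies on that assumption. The linked PR takes a different route and adds a way to place the vLLM model on a specific device. The helper names `split_gpus` and `env_for` are hypothetical.

```python
import os
from typing import Dict, List, Tuple


def split_gpus(num_gpus: int, num_inference: int = 1) -> Tuple[List[int], List[int]]:
    """Reserve the last `num_inference` GPUs for vLLM; the rest are for training."""
    if not 0 < num_inference < num_gpus:
        raise ValueError("need at least one GPU for each of training and inference")
    boundary = num_gpus - num_inference
    return list(range(boundary)), list(range(boundary, num_gpus))


def env_for(devices: List[int]) -> Dict[str, str]:
    """Copy of the current environment restricted to `devices`."""
    env = dict(os.environ)
    # Inside the child process the first listed GPU appears as cuda:0.
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in devices)
    return env


if __name__ == "__main__":
    train_gpus, vllm_gpus = split_gpus(8)
    print("training:", env_for(train_gpus)["CUDA_VISIBLE_DEVICES"])
    print("vLLM:", env_for(vllm_gpus)["CUDA_VISIBLE_DEVICES"])
```

With 8 GPUs this prints `training: 0,1,2,3,4,5,6` and `vLLM: 7`. The training launcher and the vLLM process would each be started with the corresponding environment.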

jacobthebanana commented 6 months ago


That looks like a very elegant way of implementing both inference during training and weight hot-swapping! Indeed, running inference requires far less GPU memory than training, and vLLM further reduces the memory requirement with paged attention. In terms of memory, one GPU should be more than enough to run the vLLM engine.
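Here is a rough back-of-the-envelope check of the memory claim. It assumes a hypothetical 7B-parameter model, 16-bit inference weights, and full fine-tuning with mixed-precision Adam. For the optimizer it uses the commonly cited figure of 16 bytes per parameter: fp16 weights and gradients, plus fp32 master weights and two fp32 Adam moments. Activations and the KV cache are excluded.

```python
BYTES_PER_GB = 1e9


def inference_weights_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory for 16-bit model weights alone (KV cache not included)."""
    return n_params * bytes_per_param / BYTES_PER_GB


def adam_full_finetune_gb(n_params: float) -> float:
    """Weights + gradients + optimizer state for mixed-precision Adam (no activations)."""
    fp16_weights = 2
    fp16_grads = 2
    fp32_master_weights = 4
    fp32_adam_moments = 4 + 4  # first and second moment estimates
    per_param = fp16_weights + fp16_grads + fp32_master_weights + fp32_adam_moments
    return n_params * per_param / BYTES_PER_GB


if __name__ == "__main__":
    n = 7e9  # hypothetical 7B-parameter model
    print(f"inference weights: {inference_weights_gb(n):.0f} GB")
    print(f"full fine-tune (no activations): {adam_full_finetune_gb(n):.0f} GB")
```

This gives about 14 GB for inference weights versus about 112 GB for full fine-tuning state. With LoRA training, the optimizer state covers only the adapter parameters, so the gap is smaller. Even so, the frozen base weights and the activations still have to fit alongside it.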

I have to admit I'm not particularly familiar with Hugging Face Accelerate. I see that you've been invoking model.load_weights; do you know whether the weights go through CPU memory on the way? I'm wondering whether you've observed a significant throughput limit caused by the weight transfer.

Also, do you have a rough estimate of the work it would take to run vLLM on all 8 of these GPUs? My team at work has been looking for ways to make the most of our (limited number of) GPUs by running vLLM on the same devices as the training loop. Because torch FSDP requires exclusive access to NCCL, we ended up having to write multiple wrappers around vLLM and our own training logic. Would Accelerate (instead of FSDP) be a better choice for letting the training loop run in parallel with vLLM?

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

github-actions[bot] commented 10 hours ago

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!