vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: multi-lora support older nvidia gpus. #6123

Open wuisawesome opened 2 weeks ago

wuisawesome commented 2 weeks ago

🚀 The feature, motivation and pitch

Currently, vLLM only supports LoRA adapters on NVIDIA GPUs with compute capability >= 8.0. This request is to extend support to compute capability >= 7.5.

The limitation is that vLLM relies on the Punica kernels (https://github.com/punica-ai/punica) for efficient LoRA, and upstream Punica doesn't support older GPUs.
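For reference, a device's compute capability can be queried directly from the CUDA runtime. The snippet below is only an illustrative check (not vLLM code); T4s report 7.5, V100s report 7.0, and Ampere or newer report 8.0+:

```cuda
// Illustrative only, not vLLM code: print the device's compute capability
// and whether it clears the current (>= 8.0) or proposed (>= 7.5) floor.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, 0) != cudaSuccess) {
        std::fprintf(stderr, "no CUDA device found\n");
        return 1;
    }
    std::printf("%s: compute capability %d.%d\n", prop.name, prop.major, prop.minor);

    const bool supported_today = prop.major >= 8;  // current vLLM requirement
    const bool supported_with_7_5 =
        prop.major > 7 || (prop.major == 7 && prop.minor >= 5);  // this request

    std::printf("LoRA today: %s, with a 7.5 floor: %s\n",
                supported_today ? "yes" : "no",
                supported_with_7_5 ? "yes" : "no");
    return 0;
}
```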

Personally, I've mainly run into this problem on Kaggle, which limits you to T4s or older. Others have run into it in other environments as well: Colab (https://github.com/vllm-project/vllm/issues/5199) and V100s (https://github.com/vllm-project/vllm/issues/3826).

Alternatives

In some, but not all, cases this can be mitigated by using a newer GPU, or by merging the LoRA into the base model and swapping models instead.

Additional context

I'm willing to contribute this. I've prototyped it and verified that it's possible to do this efficiently by changing the step of vLLM's wheel build that compiles the vendored Punica kernels.

wuisawesome commented 2 weeks ago

I noticed that when building the vendored Punica kernels, the failures were all related to bf16 arithmetic operations not being defined by CUDA 12.1's headers. Building against a newer CUDA version (12.4), whose headers do define these operations, fixed the problems.

Note that I'm not sure whether building the kernels against CUDA 12.4 is desirable or good engineering practice if we still want to support CUDA 12.1. If it isn't, we can probably vendor the relevant code from the CUDA headers (though I don't have a sense of how complicated that would be).
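To make the "vendor the relevant code" idea a bit more concrete, here's a rough sketch of the kind of shim that could be carried in-tree; the name `bf16_add` is hypothetical and this is not vLLM's or Punica's actual code. Older CUDA headers gate the `__nv_bfloat16` arithmetic operators behind `__CUDA_ARCH__ >= 800` (consistent with the 12.1 failures described above), so a shim like this would fall back to fp32 math when compiling for sm_75:

```cuda
// Hypothetical shim, not vLLM/Punica code: a bf16 add that compiles for sm_75
// against older CUDA headers by falling back to fp32 arithmetic.
#include <cuda_bf16.h>

__device__ __forceinline__ __nv_bfloat16 bf16_add(__nv_bfloat16 a,
                                                  __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __hadd(a, b);  // native bf16 add on Ampere and newer
#else
    // Convert to float, add, convert back: slower, but well-defined on sm_75.
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b));
#endif
}
```

The same pattern would have to be repeated (or generated) for every bf16 operation the kernels use, which is roughly the complexity question above.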

jeejeelee commented 2 weeks ago

#5036 is working on addressing the issue you mentioned.