vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature Request] Way to specify GPU ordinal #3172

Open starmpcc opened 5 months ago

starmpcc commented 5 months ago

Hello.

I am currently using the vLLM library together with data parallelism for my projects. Up to version 0.2.6, it was possible to explicitly assign a specific GPU to each worker, which was instrumental for optimizing resource allocation, particularly when not using tensor parallelism. The code snippet below illustrates how this configuration was implemented:

import os

def child_process(rank):
    os.environ["LOCAL_RANK"] = str(rank)
    from vllm import LLM

This functionality seems to be unsupported in versions of vllm later than 0.2.6. This feature is crucial for achieving higher throughput, especially when working with smaller models on high-VRAM GPUs (e.g., a 2B parameter model on an A100 80G GPU).

Thank you!

simon-mo commented 5 months ago

I believe the environment variable PyTorch uses internally is CUDA_VISIBLE_DEVICES. Does that achieve the same effect?
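
For a single model, the suggestion amounts to something like the following minimal sketch (the GPU index and model name here are illustrative; the variable has to be set before anything initializes CUDA in the process):

import os

# Pin this process to physical GPU 1 before torch/vLLM create a CUDA context.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from vllm import LLM

llm = LLM("facebook/opt-125m")  # the selected GPU now appears as cuda:0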

starmpcc commented 5 months ago

Oh, I forgot to add a detail. CUDA_VISIBLE_DEVICES usually works; however, if some library has already initialized torch (e.g., import torch or import transformers), then it does not work.

Actually, the above snippet should be modified as follows:

import os
import torch

def child_process(rank):
    os.environ["LOCAL_RANK"] = str(rank)
    from vllm import LLM

simon-mo commented 5 months ago

Interesting. You are saying that once torch is imported, the cuda device is assigned?

starmpcc commented 5 months ago

Yes. More precisely, performing a GPU operation triggers the effect. Below is a minimal script to reproduce:

Case A (works well)

import os
from multiprocessing import Pool, set_start_method

set_start_method("spawn", force=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch
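# Importing torch alone does not create a CUDA context, so the per-rank
# CUDA_VISIBLE_DEVICES set inside child() below can still take effect.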

def child(rank):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    from vllm import LLM

    llm = LLM("/nfs_data_storage/llama2_hf/Llama-2-7b-chat-hf")

def main():
    p = Pool(2)
    p.map(child, range(2))

if __name__ == "__main__":
    main()

Output A (expected)

[0] NVIDIA A100-SXM4-80GB | 29°C,   0 % | 13535 / 81920 MB |
[1] NVIDIA A100-SXM4-80GB | 31°C,   0 % | 13535 / 81920 MB |

Case B (Doesn't work)

import os
from multiprocessing import Pool, set_start_method

set_start_method("spawn", force=True)
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

import torch

torch.tensor(1).to("cuda:0")
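# The GPU operation above initializes CUDA while CUDA_VISIBLE_DEVICES is still
# "0,1". With the spawn start method, this module-level line also re-runs in each
# child when the main module is re-imported, which is likely why the per-rank
# setting inside child() below no longer has any effect.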

def child(rank):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
    from vllm import LLM

    llm = LLM("/nfs_data_storage/llama2_hf/Llama-2-7b-chat-hf")

def main():
    p = Pool(2)
    p.map(child, range(2))

if __name__ == "__main__":
    main()

Output B (unexpected)

[0] NVIDIA A100-SXM4-80GB | 29°C,   0 % | 27440 / 81920 MB |
[1] NVIDIA A100-SXM4-80GB | 30°C,   0 % |     7 / 81920 MB |

My solution above using LOCAL_RANK also works in the second case, but it only operates with vllm<=0.2.6.
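
For what it's worth, one way to detect this situation is to check whether the current process has already created a CUDA context before trying to remap devices. A minimal sketch using torch.cuda.is_initialized() (not vLLM-specific):

import os

import torch

# Once a CUDA context exists, later changes to CUDA_VISIBLE_DEVICES are
# expected to be ignored by this process.
if torch.cuda.is_initialized():
    raise RuntimeError("CUDA already initialized; set CUDA_VISIBLE_DEVICES earlier")
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # safe to remap before first CUDA use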

kaifronsdal commented 2 months ago

I've also run into this limitation a bunch, where I'd like to run multiple models in the same script on different GPUs. Ideally there would be a way to specify the GPUs a model should be placed on for each LLM(...) initialization. Setting CUDA_VISIBLE_DEVICES works great for a single model (and can be done programmatically by setting os.environ["CUDA_VISIBLE_DEVICES"] before importing torch), but it doesn't work for multiple models in the same process.

There was some discussion of a fix here, but it would need to make it into a release.
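
In the meantime, a per-process workaround along the lines of what starmpcc describes above is to give each model its own spawned process and set CUDA_VISIBLE_DEVICES there before torch or vLLM initialize CUDA. A sketch (the model names, prompts, and GPU mapping are illustrative, not a vLLM feature):

import os
from multiprocessing import Process, Queue, set_start_method


def worker(gpu_id, model_name, prompts, results):
    # Must run before importing vllm (and before anything touches CUDA).
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    from vllm import LLM

    llm = LLM(model_name)
    outputs = llm.generate(prompts)
    results.put((gpu_id, [o.outputs[0].text for o in outputs]))


def main():
    set_start_method("spawn", force=True)
    results = Queue()
    procs = [
        Process(target=worker, args=(0, "facebook/opt-125m", ["Hello"], results)),
        Process(target=worker, args=(1, "facebook/opt-350m", ["Hello"], results)),
    ]
    for p in procs:
        p.start()
    for _ in procs:
        print(results.get())  # one (gpu_id, texts) tuple per worker
    for p in procs:
        p.join()


if __name__ == "__main__":
    main()

Because each model lives in its own process, each one gets its own CUDA context, so the device remapping stays reliable.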