OptimalScale / LMFlow

An Extensible Toolkit for Finetuning and Inference of Large Foundation Models. Large Models for All.
https://optimalscale.github.io/LMFlow/
Apache License 2.0

[Feature] vllm inferencer and memory safe vllm inferencer #860

Closed wheresmyhair closed 1 week ago

wheresmyhair commented 1 week ago

Description

We add a vLLM inferencer and a memory-safe vLLM inferencer, which will benefit the online RLHF process. `MemorySafeVLLMInferencer` runs `lmflow/pipeline/utils/memory_safe_vllm_inference.py` in a Python subprocess, since it is currently not possible to offload the model or release the GPU memory that vLLM holds from within a Python script using `del`, `model.to('cpu')`, or other approaches (see this issue).
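
A minimal sketch of that workaround: launching the inference script as a separate process guarantees that the GPU memory it allocates is fully released when the child process exits. The CLI flags below are placeholders, not the script's actual argument parser.

```python
import subprocess
import sys

# Sketch only: run the vLLM inference script in its own process so that the
# GPU memory it allocates is released when that process exits.
cmd = [
    sys.executable,
    "lmflow/pipeline/utils/memory_safe_vllm_inference.py",
    # placeholder flags; the script's real CLI is defined by its own parser
    "--model_name_or_path", "Qwen/Qwen2-0.5B",
    "--dataset_path", "data/my_eval_set",
]

run_res = subprocess.run(cmd, capture_output=True, text=True)
if run_res.returncode != 0:
    raise RuntimeError(f"memory-safe vLLM inference failed:\n{run_res.stderr}")
print(run_res.stdout)
```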

Tests

MemorySafeVLLMInferencer

  1. runtime image

  2. test result image

Compatibility

  1. run_reward_modeling.sh rm_res

  2. run_finetune.sh finetune_res

  3. run_finetune_with_lora.sh lora_res

wheresmyhair commented 1 week ago

Changes made; tests to be done.

requirements.txt

  • [Feature] line 8: pin `deepspeed <= 0.14.0` to ensure backward compatibility.

src/lmflow/args.py

  • [Style] line 15: we generally sort imported packages alphabetically. Moving this to line 16 would be better.
  • [Architecture] line 99, 318: the meaning of the `load_on_init` argument may be confusing to users.
  • [Architecture] line 318-335: these arguments belong to the Inferencer, not the Model, and should be moved to `InferencerArguments`. If the model needs them, they can be passed in as `**kwargs`.
  • [Style] line 949-1001: if these options are for vLLM only, it would be better to prepend a `vllm_` prefix. Alternatively, implement the features corresponding to those arguments.
  • [Feature] line 976: better to automatically detect `os.environ["CUDA_VISIBLE_DEVICES"]` (see the detection sketch below).

if name == "main": cmd = "python /vol/yizhenjia/projs/LMFlow/runs/LMFlow-devtools/subpro_test/cudaenv.py"

run_res = subprocess.run(
    args=cmd,
    stdout=sys.stdout,
    stderr=sys.stderr,
    shell=True,
)
And in `cudaenv.py`:
```python
import torch
import subprocess
import time

if __name__ == "__main__":
    subprocess.run("echo $CUDA_VISIBLE_DEVICES", shell=True)  # this prints '1,2'
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    a = torch.Tensor([1] * 100000000).to('cuda:1')  # and this goes to GPU 2
    time.sleep(10)
```
Which results in: image
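
Following up on the `CUDA_VISIBLE_DEVICES` item above, a minimal sketch of how auto-detection could look (the helper and the fallback to `torch.cuda.device_count()` are assumptions, not existing LMFlow code):

```python
import os
import torch

def detect_visible_devices() -> list:
    """Hypothetical helper: return the indices of visible GPUs.

    Prefers CUDA_VISIBLE_DEVICES if set; otherwise falls back to all
    devices reported by torch.
    """
    env = os.environ.get("CUDA_VISIBLE_DEVICES")
    if env:
        return [int(i) for i in env.split(",") if i.strip() != ""]
    return list(range(torch.cuda.device_count()))

# e.g. use the number of visible devices as the vLLM tensor parallel size
tensor_parallel_size = len(detect_visible_devices())
```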

src/lmflow/models/hf_decoder_model.py

  • [Architecture] line 377, 429, 471: add an argument `use_vllm`, which is passed from the Inferencer (see the sketch below).
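
A minimal sketch of how `use_vllm` could be threaded through (class and method names here are illustrative stand-ins, not the actual `hf_decoder_model.py` code):

```python
from dataclasses import dataclass

@dataclass
class InferencerArgumentsSketch:
    # illustrative subset of arguments; not the real dataclass
    use_vllm: bool = False

class HFDecoderModelSketch:
    """Illustrative stand-in for hf_decoder_model.py."""

    def inference(self, inputs, use_vllm: bool = False, **kwargs):
        # dispatch to the vLLM backend only when the Inferencer asks for it
        if use_vllm:
            return self._vllm_inference(inputs, **kwargs)
        return self._hf_inference(inputs, **kwargs)

    def _vllm_inference(self, inputs, **kwargs):
        raise NotImplementedError  # would be backed by vllm.LLM

    def _hf_inference(self, inputs, **kwargs):
        raise NotImplementedError  # would be backed by transformers generate()
```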

src/lmflow/models/hf_model_mixin.py

  • [Architecture] line 111: pass this from the Inferencer and specify it as an extra argument of `__init__`.
  • [Architecture] line 368-419: the indentation level is too deep now; consider wrapping this part of the code in a separate function.
  • [Architecture] line 453: the `LLM` should not be `self.backend_model`; use another attribute, such as `self.backend_model_for_inference`, otherwise it will conflict with other usages of `self.backend_model` (see the sketch after the experiment below).
  • [Question] line 453: Does vLLM support dynamically changing the model during inference?

Regarding the dynamic model change question (line 453), a quick test:

```python
import time

from vllm import LLM, SamplingParams

sampling_params = SamplingParams()

if __name__ == "__main__":
    llm = LLM(
        model='/home/yizhenjia/.cache/huggingface/hub/models--Qwen--Qwen2-0.5B/snapshots/ff3a49fac17555b8dfc4db6709f480cc8f16a9fe',
        tensor_parallel_size=1,
        gpu_memory_utilization=0.95,
    )
    res = llm.generate("hi", sampling_params)
    print(res)
    time.sleep(10)
    print('change model')
    llm = LLM(
        "meta-llama/Meta-Llama-3-8B-Instruct",
    )
    res = llm.generate("hi", sampling_params)
    print(res)
    print('finish')
```


This results in:
![image](https://github.com/OptimalScale/LMFlow/assets/79436959/ab5cd00a-d515-409a-a81a-5c977580b6f9)
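
On the `self.backend_model_for_inference` point above, a minimal sketch of keeping the vLLM engine separate from `self.backend_model` (attribute and method names are assumptions, not the current `hf_model_mixin.py` code):

```python
class HfModelMixinSketch:
    """Illustrative stand-in for hf_model_mixin.py."""

    def __init__(self, *args, use_vllm: bool = False, **kwargs):
        self.backend_model = None                # transformers model, as before
        self.backend_model_for_inference = None  # vLLM engine lives here instead
        self.use_vllm = use_vllm

    def activate_model_for_inference(self, model_path: str, **vllm_kwargs):
        # Keep the vLLM engine on a separate attribute so existing code paths
        # that rely on self.backend_model (training, saving, etc.) are unaffected.
        from vllm import LLM
        self.backend_model_for_inference = LLM(model=model_path, **vllm_kwargs)
```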

### `src/lmflow/pipeline/inferencerv2.py`
> * We can rename it to `vllm_inferencer.py`. This matches the class name. Also, "v2" is vague and confusing.

- ✅ 

### `src/lmflow/pipeline/utils/collections.py`
> * [Style] Better to rename it. The name `collections` is vague and confusing.

- ✅ Removed, since functions are moved to other modules.

> * [Architecture] line 15: This is a util function for models; move it to `src/lmflow/utils/model.py`. `src/lmflow/pipeline/utils/` is mainly for customized training classes such as `raft_trainer`.

- ✅ 

> * [Architecture] line 28: This is a util function for datasets; move it to `src/lmflow/utils/dataset.py`.

- ✅  Moved to `src/lmflow/utils/args.py`. This function takes a list of dataclass objects and converts them into shell-command format (like `--arg1 value1 --arg2 value2`). It is used in `MemorySafeVLLMInferencer`, since that inferencer runs the command through subprocess (see the sketch below).
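
A minimal sketch of what such a conversion can look like, assuming flat dataclasses; the function name and the handling of booleans/`None` are assumptions, not the actual implementation in `src/lmflow/utils/args.py`:

```python
from dataclasses import dataclass, fields
from typing import List, Optional

def dataclasses_to_cli_args(objs: List[object]) -> str:
    """Hypothetical helper: render dataclass fields as '--name value' pairs."""
    parts = []
    for obj in objs:
        for f in fields(obj):
            value = getattr(obj, f.name)
            if value is None:
                continue  # skip unset options
            if isinstance(value, bool):
                if value:
                    parts.append(f"--{f.name}")  # boolean flags carry no value
                continue
            parts.append(f"--{f.name} {value}")
    return " ".join(parts)

@dataclass
class DemoArgs:
    model_name_or_path: str = "Qwen/Qwen2-0.5B"
    temperature: float = 1.0
    use_beam_search: bool = False
    dataset_path: Optional[str] = None

print(dataclasses_to_cli_args([DemoArgs()]))
# --model_name_or_path Qwen/Qwen2-0.5B --temperature 1.0
```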

### `src/lmflow/pipeline/utils/memory_safe_vllm_inference.py`
> * [Architecture] Move it to `examples/memory_safe_vllm_inference.py`, or make it a special mode of the common inference, e.g. a mode activated by a single `--use_vllm` option.

- ❓ Here comes the tricky part. This works as a module that supports `MemorySafeVLLMInferencer`, rather than just a standalone workflow. When `MemorySafeVLLMInferencer` is used, its `.inference()` method runs this script in a subprocess and returns the result. This is only a workaround for the vLLM in-Python memory-releasing issue.

### `src/lmflow/utils/collections.py`
> * [Architecture] Move the content to `src/lmflow/utils/dataset.py`

- ✅  Removed `src/lmflow/utils/collections.py`. Moved `create_copied_dataclass()` and `remove_dataclass_attr_prefix()` to `src/lmflow/utils/args.py`. They are useful when two or more models need to be loaded through the CLI command (PPO, for example, requires a reward model and an SFT model at the same time); see the sketch below.
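
A minimal sketch of the idea behind these two helpers (signatures and behavior are assumptions based on their names, not the actual implementation):

```python
from dataclasses import dataclass, fields, make_dataclass

@dataclass
class ModelArgumentsSketch:
    model_name_or_path: str = "facebook/opt-125m"
    torch_dtype: str = "bfloat16"

def create_copied_dataclass_sketch(original, field_prefix: str, class_prefix: str):
    # Duplicate a dataclass with prefixed field names so two models can be
    # configured from one command line (e.g. --reward_model_name_or_path ...).
    new_fields = [
        (field_prefix + f.name, f.type, f.default) for f in fields(original)
    ]
    return make_dataclass(class_prefix + original.__name__, new_fields)

def remove_dataclass_attr_prefix_sketch(obj, prefix: str) -> dict:
    # Strip the prefix again so the values map back onto the original fields.
    return {
        f.name[len(prefix):]: getattr(obj, f.name)
        for f in fields(obj)
        if f.name.startswith(prefix)
    }

RewardModelArguments = create_copied_dataclass_sketch(
    ModelArgumentsSketch, field_prefix="reward_", class_prefix="Reward"
)
reward_args = RewardModelArguments(reward_model_name_or_path="path/to/reward_model")
print(remove_dataclass_attr_prefix_sketch(reward_args, "reward_"))
# {'model_name_or_path': 'path/to/reward_model', 'torch_dtype': 'bfloat16'}
```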

### `tests/pipeline/test_memory_safe_vllm_inferencer.py`
> * [Style] line 16, 23, 34: there are absolute paths; consider uploading the dataset and using Hugging Face model names.

- ✅ 
wheresmyhair commented 1 week ago

Tests after arch change

MemorySafeVLLMInferencer

  1. runtime image

  2. test result image

Compatibility

  1. run_reward_modeling.sh image

  2. run_finetune.sh image

  3. run_finetune_with_lora.sh image