Open · zhuzilin opened 2 weeks ago
Thank you for this feature request. I think it is indeed needed. It looks like OpenRLHF calls `load_weights` directly by wrapping the worker: https://github.com/OpenRLHF/OpenRLHF/blob/f631ebe865699fe9caa12de117eba999d5aa2372/openrlhf/trainer/ray/vllm_worker_wrap.py#L41
Clearing the cache is a bit complex to do, but I can share some pointers here if anyone would like to take a stab:

- `self.cache_engine` (on the worker)
- `BlockManagerV2`, which is inside `LLMEngine`. Just clearing this should be sufficient.
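For whoever takes a stab at this, here is a rough sketch of where those pieces live. The attribute paths below are assumptions based on vLLM's current layout (each `Scheduler` inside `LLMEngine` owns the block manager), and the loop body is intentionally left as the open TODO that this issue asks for:

```python
# Sketch only: attribute names assume vLLM's current (~v0.6) layout and may need
# adapting; the actual clearing logic is the open part of this feature request.
def invalidate_prefix_cache(llm_engine) -> None:
    """Forget prefix-cache bookkeeping after the workers reload weights."""
    # llm_engine.scheduler is a list of Scheduler objects (one per virtual
    # engine); each owns the block manager (BlockManagerV2 when
    # use_v2_block_manager is enabled) that records which blocks are already
    # computed and can therefore be matched as cached prefixes.
    for scheduler in llm_engine.scheduler:
        block_manager = scheduler.block_manager
        # TODO: drop the "computed/cached" block bookkeeping here so new requests
        # cannot hit KV blocks produced by the old weights. The KV tensors in each
        # worker's self.cache_engine likely do not need to be zeroed, since a block
        # is recomputed once it is allocated fresh.
```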
🚀 The feature, motivation and pitch
When using vLLM to generate rollouts in typical RLHF training (e.g. as in OpenRLHF/OpenRLHF), we need to reload the weights of each vLLM-served model after each training round. On the other hand, many RLHF algorithms sample multiple responses for a given prompt (e.g. GRPO, RLOO, and variants of PPO), which makes prefix caching an important feature for rollout performance.
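For concreteness, the rollout pattern looks roughly like the sketch below (the model name and sampling values are placeholders); with `enable_prefix_caching=True`, the n samples for a prompt reuse the prompt's cached KV blocks:

```python
from vllm import LLM, SamplingParams

# Minimal sketch of the rollout pattern described above; model name and
# sampling values are placeholders.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", enable_prefix_caching=True)

# GRPO/RLOO-style rollout: several samples per prompt, so the shared prompt
# prefix is computed once and then reused from the prefix cache.
prompts = ["<prompt from the current RLHF batch>"]
sampling_params = SamplingParams(n=8, temperature=1.0, max_tokens=512)
outputs = llm.generate(prompts, sampling_params)
```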
However, in the current vLLM implementation the cached prefix is not invalidated after `load_weights`, so the new model keeps serving KV entries computed with the old weights. It would be great to support an API that invalidates the stale KV cache, or to do this automatically inside `load_weights`. A reference implementation is sglang/srt/managers/scheduler.py#L1143.
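To make the request concrete, one round of the training loop might look like the sketch below. `push_weights_to_vllm` is a placeholder for however the trainer delivers new weights (OpenRLHF goes through its worker wrapper), and `reset_prefix_cache` is a hypothetical name for the API this issue asks for, not an existing vLLM method:

```python
# Hypothetical sketch of the requested behaviour: push_weights_to_vllm stands in
# for however the trainer delivers new weights, and reset_prefix_cache is an
# illustrative name, not an existing vLLM API.
def rollout_round(llm, prompts, sampling_params, new_state_dict):
    """One RLHF round once the requested invalidation API exists."""
    # 1) Reload the trained weights into the serving model (this ends up calling
    #    model.load_weights on each worker).
    push_weights_to_vllm(llm, new_state_dict)  # placeholder helper

    # 2) Requested feature: invalidate every cached KV block computed with the
    #    old weights, either explicitly like this or automatically inside the
    #    weight-loading path.
    llm.llm_engine.reset_prefix_cache()  # hypothetical API

    # 3) Generate the next rollouts with prefix caching still enabled.
    return llm.generate(prompts, sampling_params)
```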
Thank you for your time on this feature request :)
Alternatives
No response
Additional context
No response