vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: Support a `flush_cache` API to clean the kvcache after `load_weights` #9744

Open zhuzilin opened 2 weeks ago

zhuzilin commented 2 weeks ago

šŸš€ The feature, motivation and pitch

When using vLLM to generate rollouts in typical RLHF training (e.g. as in OpenRLHF/OpenRLHF), we need to reload the weights of each vLLM-served model after each rollout training round. At the same time, many RLHF algorithms need to sample multiple responses from a given prompt (e.g. GRPO, RLOO, and variants of PPO), which makes prefix caching an important feature for RLHF performance.

However, in the current vLLM implementation, the cached prefix is not invalidated after load_weights, so the new model runs against KV cache computed with the old weights. It would be great to support an API to invalidate the old KV cache, or to do that automatically within load_weights.

A reference implementation can be found in sglang/srt/managers/scheduler.py#L1143:

    def update_weights(self, recv_req: UpdateWeightReqInput):
        """In-place update of the weights."""
        success, message = self.tp_worker.update_weights(recv_req)
        if success:
            flash_cache_success = self.flush_cache()
            assert flash_cache_success, "Cache flush failed after updating weights"
        else:
            logger.error(message)
        return success, message
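
For illustration, here is a minimal sketch of how an analogous flow might look on the vLLM side. The `flush_cache` method and the engine-level `load_weights` entry point assumed below are hypothetical; this only shows the desired call pattern, not an existing vLLM API:

    # Hypothetical sketch only: `engine.load_weights` and `engine.flush_cache`
    # are assumed entry points, not existing vLLM APIs.
    def update_weights_and_flush(engine, named_weights):
        """Reload weights in place, then drop all cached prefix/KV blocks."""
        engine.load_weights(named_weights)
        # Without this step, new requests may reuse prefix KV computed with the
        # previous policy weights, which is the staleness issue described above.
        flushed = engine.flush_cache()
        assert flushed, "Cache flush failed after updating weights"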

Thank you for your time on this feature request :)

Alternatives

No response

Additional context

No response


simon-mo commented 2 weeks ago

Thank you for this feature request. I think it is indeed needed. It looks like OpenRLHF calls load_weights directly by wrapping the worker: https://github.com/OpenRLHF/OpenRLHF/blob/f631ebe865699fe9caa12de117eba999d5aa2372/openrlhf/trainer/ray/vllm_worker_wrap.py#L41
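
A simplified sketch of that pattern is below; the actual OpenRLHF wrapper differs in details such as the process-group setup and dtype handling, so treat the names here as illustrative only:

    # Illustrative sketch of the OpenRLHF-style wrapper: the vLLM worker is
    # subclassed so the trainer can push updated weights into the live model.
    import torch
    from vllm.worker.worker import Worker

    class WorkerWrap(Worker):
        def update_weight(self, name, dtype, shape):
            # Receive the new tensor from the trainer process, then load it
            # into the already-initialized model in place.
            weight = torch.empty(shape, dtype=dtype, device="cuda")
            torch.distributed.broadcast(weight, src=0)  # assumes a shared process group
            self.model_runner.model.load_weights(weights=[(name, weight)])
            del weight
            # Note: nothing here invalidates previously cached prefix blocks,
            # which is the gap this issue asks vLLM to close.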

Clearing the cache is a bit complex to do, but I can share some pointers here if anyone would like to take a stab: