Blaizzy opened this issue 1 day ago
You need to run the model inside set_forward_context. Example: https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py#L1656
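A simplified sketch of the pattern in that file (the exact set_forward_context signature differs between vLLM versions; model_input, model_executable, kv_caches, and intermediate_tensors are the objects the model runner already builds for you):

```python
# Simplified sketch of the pattern in vllm/worker/model_runner.py; the exact
# set_forward_context signature varies across vLLM versions.
from vllm.forward_context import set_forward_context

with set_forward_context(model_input.attn_metadata):
    hidden_states = model_executable(
        input_ids=model_input.input_tokens,
        positions=model_input.input_positions,
        kv_caches=kv_caches,
        attn_metadata=model_input.attn_metadata,
        intermediate_tensors=intermediate_tensors,
    )
```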
Thanks @DarkLight1337!
But how do I get the model_input.attn_metadata? It's not very clear how to implement it in my example.
This is the tricky part, since it's quite integral to how vLLM operates (e.g. KV cache, prefix caching, chunked prefill...). I guess your custom cache most likely interferes with how vLLM does it by default.
I'm not familiar with this part of the code so can't offer suggestions. Perhaps @comaniac could help?
Yep, very tricky.
For now, all I want is to know what to pass to the decoder layers:
```python
for layer in model_obj.model.layers[:n_layers]:
    attn_output = layer(
        positions=position_ids,        # <--- here
        hidden_states=hidden_states,
        kv_cache=past_key_values,      # <--- here
        attn_metadata=None,            # <--- here
        residual=residual,             # <--- here
    )
```
We don't support this use case atm, so there might be many unexpected behaviors. I'd suggest cloning a model file in vLLM and registering your customized version as an out-of-tree plugin model.
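A minimal sketch of that route (names are hypothetical; MyLlamaForCausalLM would be your copy of vllm/model_executor/models/llama.py with the custom layer logic added):

```python
# Minimal sketch of out-of-tree ("plugin") model registration.
# `my_llama.MyLlamaForCausalLM` is a hypothetical copy of
# vllm/model_executor/models/llama.py with the custom layer logic added.
from vllm import LLM, ModelRegistry
from my_llama import MyLlamaForCausalLM

# Register before constructing the engine; the name should match the
# `architectures` entry in the model's HF config so vLLM picks up your class.
ModelRegistry.register_model("LlamaForCausalLM", MyLlamaForCausalLM)

llm = LLM(model="your/model-name", dtype="bfloat16")
```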
Thanks @comaniac!
Could you provide an example of how to implement this for my use case?
I want to pass the number of layers and a few other arguments at inference time.
For example:
```python
model = LLM(model=model_name, dtype="bfloat16", gpu_memory_utilization=gpu_memory_utilization)
model.generate(inputs, num_layers, arg1, arg2)
```
Even a high-level one would help.
Do your args change per request, or are they determined when launching the engine?
If it's per request, then yes, you need to customize the generate function. For prototyping, I'd suggest directly hacking vLLM's .generate() first instead of implementing one outside the core.
They change per request.
Then you may try to change llm.generate() (https://github.com/vllm-project/vllm/blob/08075c34483843c75b4420bac92377b59ff9a8ac/vllm/entrypoints/llm.py#L295). One quick way I can think of to make it work is adding your arguments to SamplingParams so that they can be passed all the way to the model runner inputs.
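For example, the call site could then look like this (a hypothetical sketch: num_layers, arg1, and arg2 are not fields of the stock SamplingParams; this assumes you have added them and plumbed them through to the model runner):

```python
# Hypothetical sketch: assumes num_layers, arg1, and arg2 have been added as
# custom fields on vllm.SamplingParams (they do not exist in the stock class)
# and are read back out in the model runner.
from vllm import LLM, SamplingParams

llm = LLM(model=model_name, dtype="bfloat16",
          gpu_memory_utilization=gpu_memory_utilization)
params = SamplingParams(max_tokens=64, num_layers=num_layers, arg1=arg1, arg2=arg2)
outputs = llm.generate(inputs, params)
```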
Your current environment
How would you like to use vllm
I'm implementing a custom algorithm that requires a custom generate method.
In this method, I need to access and store some of the attention outputs without running a full forward pass over the whole model, as displayed below. But I keep getting errors related to attn_metadata. I tried multiple options, such as using some of the abstractions in attn_metadata.py and model_runner.py, but with no success. This is very easy to do in transformers, and I have a working implementation there, but I'm struggling to port it to vLLM.
Traceback
Expected result:
Access and store intermediate results of the model directly without having to run a full forward pass.