vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Usage]: Custom LLM Generate #9551

Open Blaizzy opened 1 day ago

Blaizzy commented 1 day ago

Your current environment

The output of `python collect_env.py`

How would you like to use vllm

I'm implementing a custom algorithm that requires a custom generate method.

In this method, I need to access and store some of the attention outputs without running a full forward pass of the whole model, as shown below. But I keep getting errors related to attn_metadata. I tried multiple options, such as using some of the abstractions in attn_metadata.py and model_runner.py, but with no success.

This is very easy to do in transformers and I have a working implementation, but I'm struggling to port it to vLLM.

import torch
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from typing import List, Tuple, Dict, Any, Optional

# model_name, gpu_memory_utilization, and input_ids are defined earlier in my code
model = LLM(model=model_name, dtype="bfloat16", gpu_memory_utilization=gpu_memory_utilization)
model_obj = model.llm_engine.model_executor.driver_worker.model_runner.model
hidden_states = model_obj.model.get_input_embeddings(input_ids)

device = hidden_states.device
attention_outputs = []
cache_position = None
past_key_values = None
position_ids = None
n_layers = 2

for layer in model_obj.model.layers[:n_layers]:
    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1]).to(device)

    if position_ids is None:
        position_ids = cache_position.unsqueeze(0)

    # Store the input as the residual
    residual = hidden_states

    # Pass attn_metadata and residual to the layer
    # (this call is what raises the AssertionError below)
    layer_output = layer(
        positions=position_ids,
        hidden_states=hidden_states,
        kv_cache=past_key_values,
        attn_metadata=None,
        residual=residual,
    )

    # Unpack the layer output
    hidden_states = layer_output[0]

    # attention_output comes from my patched attention module
    # (the stock LlamaDecoderLayer does not return it)
    if attention_output is not None:
        attention_outputs.append(attention_output.detach().to("cpu"))

Traceback

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[1], line 123
    121 model = Model("../workspace/llama_models/llama-8B-Instruct")
    122 messages = [{"role": "user", "content": "Suppose you are provided a 7L bucket. Also, suppose that, afterwards, the floor of that bucket was removed, and its ceiling was sealed. How much water this bucket can hold?"}]
--> 123 output = model.generate(messages, stream=True)
    124 print(output)

Cell In[1], line 89, in Model.generate(self, messages, max_new_tokens, temperature, do_sample, stream)
     86 input_ids = input_ids.input_ids
     88 # Find the context
---> 89 context_indices = self.find_context(input_ids).to(self.device)
     91 if input_ids.device != self.device:
     92     input_ids.to(self.device)

Cell In[1], line 65, in Model.find_context(self, input_ids)
     64 # Get the hidden states and attention matrices
---> 65 attentions, _ = self.get_attention_first_n_layers(input_ids, self.filter_layers)

Cell In[1], line 43, in Model.get_attention_first_n_layers(self, input_ids, n_layers)
     40 residual = hidden_states
     42 # Pass attn_metadata and residual to the layer
---> 43 layer_output = layer(
     44     positions=position_ids,
     45     hidden_states=hidden_states,
     46     kv_cache=past_key_values,
     47     attn_metadata=None,
     48     residual=residual
     49 )
     51 # Unpack the layer output
     52 hidden_states = layer_output[0]

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py:259, in LlamaDecoderLayer.forward(self, positions, hidden_states, kv_cache, attn_metadata, residual)
    256 else:
    257     hidden_states, residual = self.input_layernorm(
    258         hidden_states, residual)
--> 259 hidden_states = self.self_attn(positions=positions,
    260                                hidden_states=hidden_states,
    261                                kv_cache=kv_cache,
    262                                attn_metadata=attn_metadata)
    264 # Fully Connected
    265 hidden_states, residual = self.post_attention_layernorm(
    266     hidden_states, residual)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.11/dist-packages/vllm/model_executor/models/llama.py:189, in LlamaAttention.forward(self, positions, hidden_states, kv_cache, attn_metadata)
    187 q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    188 q, k = self.rotary_emb(positions, q, k)
--> 189 attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
    190 output, _ = self.o_proj(attn_output)
    191 return output

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.11/dist-packages/vllm/attention/layer.py:100, in Attention.forward(self, query, key, value, kv_cache, attn_metadata, attn_type)
     90 def forward(
     91     self,
     92     query: torch.Tensor,
   (...)
     97     attn_type: AttentionType = AttentionType.DECODER,
     98 ) -> torch.Tensor:
--> 100     return self.impl.forward(query,
    101                              key,
    102                              value,
    103                              kv_cache,
    104                              attn_metadata,
    105                              self._k_scale,
    106                              self._v_scale,
    107                              attn_type=attn_type)

File /usr/local/lib/python3.11/dist-packages/vllm/attention/backends/flash_attn.py:584, in FlashAttentionImpl.forward(self, query, key, value, kv_cache, attn_metadata, k_scale, v_scale, attn_type)
    580 # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
    581 assert k_scale == 1.0 and v_scale == 1.0, (
    582     "key/v_scale is not supported in FlashAttention.")
--> 584 output = torch.ops.vllm.unified_flash_attention(
    585     query,
    586     key,
    587     value,
    588     self.num_heads,
    589     self.head_size,
    590     self.num_kv_heads,
    591     kv_cache,
    592     self.kv_cache_dtype,
    593     k_scale,
    594     v_scale,
    595     self.scale,
    596     self.sliding_window,
    597     self.alibi_slopes,
    598     self.logits_soft_cap,
    599 )
    601 return output

File /usr/local/lib/python3.11/dist-packages/torch/_ops.py:1061, in OpOverloadPacket.__call__(self_, *args, **kwargs)
   1059 if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
   1060     return _call_overload_packet_from_python(self_, args, kwargs)
-> 1061 return self_._op(*args, **(kwargs or {}))

File /usr/local/lib/python3.11/dist-packages/torch/_library/autograd.py:98, in make_autograd_impl.<locals>.autograd_impl(keyset, *args, **keyword_only_args)
     97 def autograd_impl(keyset, *args, **keyword_only_args):
---> 98     result = Generated.apply(*args, Metadata(keyset, keyword_only_args))  # type: ignore[attr-defined]
     99     return result

File /usr/local/lib/python3.11/dist-packages/torch/autograd/function.py:574, in Function.apply(cls, *args, **kwargs)
    571 if not torch._C._are_functorch_transforms_active():
    572     # See NOTE: [functorch vjp and autograd interaction]
    573     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 574     return super().apply(*args, **kwargs)  # type: ignore[misc]
    576 if not is_setup_ctx_defined:
    577     raise RuntimeError(
    578         "In order to use an autograd.Function with functorch transforms "
    579         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    580         "staticmethod. For more details, please see "
    581         "https://pytorch.org/docs/main/notes/extending.func.html"
    582     )

File /usr/local/lib/python3.11/dist-packages/torch/_library/autograd.py:40, in make_autograd_impl.<locals>.forward(ctx, *args)
     38 keyset = metadata.keyset
     39 kwargs = metadata.keyword_only_args
---> 40 result = op.redispatch(keyset & _C._after_autograd_keyset, *args, **kwargs)
     41 if info._setup_context_fn:
     42     # The Dispatcher will remove args that are equal to their default
     43     # values from (args, kwargs). We're going to add it back so that
   (...)
     50     # their setup_context (along with the rest of their operator
     51     # registrations)
     52     args, kwargs = utils.fill_defaults(op._schema, args, kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/_ops.py:672, in OpOverload.redispatch(self_, keyset, *args, **kwargs)
    669 def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
    670     # use `self_` to avoid naming collide with aten ops arguments that
    671     # are named "self". This way, all the aten ops can be called by kwargs.
--> 672     return self_._handle.redispatch_boxed(keyset, *args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/_library/custom_ops.py:494, in CustomOpDef._register_to_dispatcher.<locals>.adinplaceorview_impl(keyset, *args, **kwargs)
    492                 autograd.graph.increment_version(v)
    493 with _C._AutoDispatchBelowADInplaceOrView():
--> 494     return self._opoverload.redispatch(
    495         keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
    496     )

File /usr/local/lib/python3.11/dist-packages/torch/_ops.py:672, in OpOverload.redispatch(self_, keyset, *args, **kwargs)
    669 def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
    670     # use `self_` to avoid naming collide with aten ops arguments that
    671     # are named "self". This way, all the aten ops can be called by kwargs.
--> 672     return self_._handle.redispatch_boxed(keyset, *args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/_library/custom_ops.py:236, in CustomOpDef.register_kernel.<locals>.inner.<locals>.backend_impl(*args, **kwargs)
    228 def backend_impl(*args, **kwargs):
    229     # Checks the assumption that outputs cannot alias
    230     # inputs or other outputs.
    231     storages = {
    232         id(tensor.untyped_storage())
    233         for tensor in iter_tensors(args, kwargs)
    234     }
--> 236     result = self._backend_fns[device_type](*args, **kwargs)
    238     tuple_result = result
    239     if not isinstance(result, tuple):

File /usr/local/lib/python3.11/dist-packages/vllm/attention/backends/flash_attn.py:624, in unified_flash_attention(query, key, value, num_heads, head_size, num_kv_heads, kv_cache, kv_cache_dtype, k_scale, v_scale, softmax_scale, window_size, alibi_slopes, logits_soft_cap)
    604 @torch.library.custom_op("vllm::unified_flash_attention",
    605                          mutates_args=["kv_cache"])
    606 def unified_flash_attention(
   (...)
    620     logits_soft_cap: Optional[float] = None,
    621 ) -> torch.Tensor:
    623     current_metadata = get_forward_context()
--> 624     assert current_metadata is not None
    625     assert isinstance(current_metadata, FlashAttentionMetadata)
    626     attn_metadata: FlashAttentionMetadata = current_metadata

AssertionError:

Expected result:

Access and store intermediate results of the model directly without having to run a full forward pass.
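
For reference, the transformers version of this is roughly the following (a minimal sketch relying on output_attentions=True; the model name and prompt are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
)

inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
with torch.no_grad():
    # output_attentions=True makes the model return one attention tensor per layer
    outputs = model(**inputs, output_attentions=True)

# Keep only the first n layers and move them off the GPU
attention_outputs = [attn.detach().to("cpu") for attn in outputs.attentions[:2]]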


DarkLight1337 commented 1 day ago

You need to run the model inside of set_forward_context. Example: https://github.com/vllm-project/vllm/blob/main/vllm/worker/model_runner.py#L1656
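
A rough, untested sketch of what that would look like in your loop (assuming set_forward_context from vllm.forward_context takes the attention metadata, as in the linked model_runner code, and that attn_metadata is a real metadata object produced by the model runner rather than None):

from vllm.forward_context import set_forward_context

# attn_metadata must be the metadata vLLM builds for the current batch;
# passing None is what triggers the assertion in unified_flash_attention.
with set_forward_context(attn_metadata):
    layer_output = layer(
        positions=position_ids,
        hidden_states=hidden_states,
        kv_cache=past_key_values,
        attn_metadata=attn_metadata,
        residual=residual,
    )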

Blaizzy commented 1 day ago

Thanks @DarkLight1337 !

But how do I get the model_input.attn_metadata?

It's not very clear how to implement it in my example.

DarkLight1337 commented 1 day ago

> Thanks @DarkLight1337 !
>
> But how do I get the model_input.attn_metadata?
>
> It's not very clear how to implement it in my example.

This is the tricky part, since it's quite integral to how vLLM operates (e.g. KV cache, prefix caching, chunked prefill...). I guess your custom cache most likely interferes with how vLLM does it by default.

DarkLight1337 commented 1 day ago

I'm not familiar with this part of the code so can't offer suggestions. Perhaps @comaniac could help?

Blaizzy commented 1 day ago

> This is the tricky part, since it's quite integral to how vLLM operates (e.g. KV cache, prefix caching, chunked prefill...). I guess your custom cache most likely interferes with how vLLM does it by default.

Yep, very tricky.

For now, all I want is to know what to pass to the decoder layers:

for layer in model_obj.model.layers[:n_layers]:
    attn_output = layer(
        positions=position_ids,       # <--- here
        hidden_states=hidden_states,
        kv_cache=past_key_values,     # <--- here
        attn_metadata=None,           # <--- here
        residual=residual,            # <--- here
    )

comaniac commented 1 day ago

We don't support this use case at the moment, so there might be many unexpected behaviors. I'd suggest cloning a model file in vLLM, customizing it, and registering it as a plugin model.
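
Roughly, the out-of-tree model route looks like this (a sketch; MyLlamaForCausalLM is a hypothetical copy of vLLM's llama.py classes with forward() modified to record the attention outputs during the normal forward pass):

from vllm import ModelRegistry
from vllm.model_executor.models.llama import LlamaForCausalLM

class MyLlamaForCausalLM(LlamaForCausalLM):
    # Override forward() (or the decoder layers) here to stash the attention
    # outputs while vLLM runs its normal forward pass, instead of calling the
    # layers by hand outside the engine.
    pass

# Register the clone so vLLM can resolve it for your checkpoint; the registered
# name has to match the architecture vLLM looks up (e.g. the hf config's
# architectures field).
ModelRegistry.register_model("MyLlamaForCausalLM", MyLlamaForCausalLM)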

Blaizzy commented 1 day ago

Thanks @comaniac!

Could you provide an example of how to implement it for my use case?

I want to pass the number of layers and a few other arguments at inference time.

For example:

model = LLM(model=model_name, dtype="bfloat16", gpu_memory_utilization=gpu_memory_utilization)

model.generate(inputs, num_layers, arg1, arg2) 

Even a high-level one would help.

comaniac commented 1 hour ago

Do your args change per request, or are they determined when launching the engine?

If it's per request, then yes, you need to customize the generate function. For prototyping, I'd suggest directly hacking vLLM's .generate() first instead of implementing one outside the core.

Blaizzy commented 1 hour ago

They change per request.

comaniac commented 56 minutes ago

Then you may try to change llm.generate() (https://github.com/vllm-project/vllm/blob/08075c34483843c75b4420bac92377b59ff9a8ac/vllm/entrypoints/llm.py#L295). One quick way I can think of to make this work is to add your arguments to SamplingParams so that they can be passed all the way down to the model runner inputs.
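
A sketch of that idea, assuming you add the extra fields to SamplingParams in your vLLM fork (num_layers below is a hypothetical custom field added to vllm/sampling_params.py):

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct", dtype="bfloat16")

# num_layers is the hypothetical per-request field; once it lives on
# SamplingParams it travels with the request down to the model runner inputs.
params = SamplingParams(temperature=0.7, max_tokens=128, num_layers=2)

outputs = llm.generate(["Suppose you are provided a 7L bucket..."], params)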