vllm-project / vllm

[RFC]: hide continuous batching complexity through forward context #9098

Open · youkaichao opened 1 month ago

youkaichao commented 1 month ago

Motivation.

Take a look at the current Llama forward computation logic:

class LlamaMLP(nn.Module):
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

class LlamaAttention(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.o_proj(attn_output)
        return output

class LlamaDecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i - self.start_layer],
                attn_metadata,
                residual,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class LlamaForCausalLM(nn.Module, SupportsLoRA):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        model_output = self.model(input_ids, positions, kv_caches,
                                  attn_metadata)
        return model_output

If we don't consider attn_metadata and kv_caches, it can be simplified to:

class LlamaMLP(nn.Module):
    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

class LlamaAttention(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

class LlamaDecoderLayer(nn.Module):
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

class LlamaModel(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class LlamaForCausalLM(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        model_output = self.model(input_ids, positions)
        return model_output

Arguably, attn_metadata is the most complicated part of the forward computation logic, and it becomes even more complicated when we consider:

Therefore, I'm considering hiding the complexity of continuous batching behind a forward context. The idea is to have a global forward context that the model runner sets during every forward pass. The forward context stores the attention metadata, and the model accesses the attention metadata through it.
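To make the idea concrete, here is a minimal sketch of what such a forward context could look like. Everything below (the ForwardContext dataclass, set_forward_context, get_forward_context, and their fields) is illustrative naming for this RFC, not necessarily the exact API landed in the linked PRs:

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

import torch


@dataclass
class ForwardContext:
    # per-forward-pass scratch state that layers can read
    attn_metadata: Any
    kv_caches: List[torch.Tensor]


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    # layers call this instead of receiving attn_metadata / kv_caches as arguments
    assert _forward_context is not None, (
        "forward context is only available inside a forward pass")
    return _forward_context


@contextmanager
def set_forward_context(attn_metadata: Any, kv_caches: List[torch.Tensor]):
    # installed by the model runner around one forward pass and torn down
    # afterwards, so nothing leaks across passes
    global _forward_context
    _forward_context = ForwardContext(attn_metadata, kv_caches)
    try:
        yield
    finally:
        _forward_context = None

With something like this in place, the simplified signatures above become possible, because the attention op can look up the metadata and KV cache itself instead of having them threaded through every forward() call.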

Proposed Change.

The changes are:

See https://github.com/vllm-project/vllm/pull/9029 and https://github.com/vllm-project/vllm/pull/9097 for the initial steps.

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

youkaichao commented 1 month ago

The model runner will set the forward context before running the model, and the forward context will be used to store the attention metadata and KV cache.

There's one alternative: the top-level model owns and sets the forward context. In the Llama case, LlamaForCausalLM sets the forward context.

This approach works better for encoder-decoder models or multi-modality models.
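As a rough usage sketch built on the hypothetical helpers above (ContextAwareAttention is a toy layer invented for illustration, not vLLM's real Attention op):

import torch
import torch.nn as nn


class ContextAwareAttention(nn.Module):
    # toy attention layer that pulls its per-batch state from the forward
    # context instead of taking kv_cache / attn_metadata as arguments
    def __init__(self, layer_index: int):
        super().__init__()
        self.layer_index = layer_index

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        ctx = get_forward_context()
        attn_metadata = ctx.attn_metadata           # e.g. seq lens, block tables
        kv_cache = ctx.kv_caches[self.layer_index]  # this layer's cache slice
        # a real implementation would dispatch an attention kernel here,
        # using kv_cache and attn_metadata; return q as a placeholder
        return q


attn = ContextAwareAttention(layer_index=0)
q = k = v = torch.randn(4, 8)

# model-runner-owned variant: the runner wraps one whole forward pass
with set_forward_context(attn_metadata={"num_tokens": 4},
                         kv_caches=[torch.zeros(1)]):
    out = attn(q, k, v)

In the model-owned alternative, LlamaForCausalLM.forward would do the same wrapping itself before delegating to self.model, which is why it composes more naturally when an encoder-decoder or multi-modality model needs to set up context for several sub-models.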

simon-mo commented 1 month ago

I'm a bit worried about the lifetime of the forward context. As in, will it be immutable until the forward pass finishes? In the case of asynchronous scheduling, when can this context be updated? Do we intend to perform the update right before dispatch? How about multistep scheduling?

youkaichao commented 1 month ago

I'm a bit worried about the lifetime of the forward context

The lifetime is the same as the model forward pass. You can treat it as scratch space for the model: the model is free to modify it, but it will be destroyed after the forward pass.

In the case of asynchronous scheduling, when can this context be updated?

Note that the lifetime of the forward context is the same as the model forward pass, and we never execute two model forward passes concurrently. All the scheduler and model runner logic is untouched; they can do whatever they want, just as in the current codebase.
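Under the sketch above, that lifetime is enforced mechanically: the context exists only inside the with-block that wraps one forward pass, and any access outside it fails loudly.

with set_forward_context(attn_metadata={"num_tokens": 4}, kv_caches=[]):
    ctx = get_forward_context()    # fine: we are inside a forward pass
    ctx.attn_metadata["step"] = 0  # the model may mutate its scratch space

get_forward_context()  # raises AssertionError: the pass is over, the context is gone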