
[RFC]: Initial support for RBLN NPU #7247

Open rebel-jonghewk opened 3 months ago

rebel-jonghewk commented 3 months ago

Motivation.

The RBLN SDK provides deep learning inference on Rebellions' NPUs, such as ATOM and REBEL, including support for large language models (LLMs). This project aims to develop an RBLN backend for vLLM, initially prioritizing the ATOM device, with plans to enable REBEL support later.

In alignment with Rebellions' Optimum Hugging Face extension documentation, the RBLN backend will support a wide range of models available in the Rebellions Model Zoo.

The project currently incorporates continuous batching and will integrate additional techniques, such as PagedAttention, to further improve performance.

Proposed Change.

Introduce the RBLN vLLM backend, covering the target models, design, and implementation details below.

Target Models

We will start by ensuring vLLM works with the Llama architecture and expand to other architectures. The full list of LLMs supported by RBLN can be viewed here.

Design

We will introduce several custom classes that align with vLLM's architecture for heterogeneous accelerators (such as Neuron, XPU, and TPU). See the diagram below for details.

[architecture diagram]
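Since the diagram is not reproduced here, the following is a minimal structural sketch, assuming a layering that mirrors vLLM's existing Neuron backend; the names RBLNExecutor, RBLNWorker, and RBLNModelRunner are illustrative placeholders, not the final design.

# Illustrative layering only; class names are assumptions, not the final design.
from typing import Any, List


class RBLNModelRunner:
    """Loads the optimum.rbln compiled model and runs prefill/decode steps."""

    def __init__(self, model_config: Any) -> None:
        self.model_config = model_config
        self.model = None

    def load_model(self) -> None:
        # See init_model() below: resolves the optimum.rbln model class and
        # loads the pre-compiled RBLN binary.
        raise NotImplementedError

    def execute_model(self, seq_group_metadata_list: List[Any]) -> Any:
        # Builds input_ids/positions from the scheduled sequence groups,
        # calls the model's forward(), and samples the next tokens.
        raise NotImplementedError


class RBLNWorker:
    """One worker per ATOM device; owns a model runner and its cache blocks."""

    def __init__(self, model_config: Any) -> None:
        self.model_runner = RBLNModelRunner(model_config)


class RBLNExecutor:
    """Executor selected by vLLM when the RBLN device type is configured."""

    def __init__(self, model_config: Any) -> None:
        self.driver_worker = RBLNWorker(model_config)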

Implementation Details

Initialize model
def init_model(self) -> None:
    config = self.model_config.hf_config
    model_name_or_path = self.model_config.model
    model_name, model_cls_name = get_rbln_model_info(config)

    # resolve the optimum.rbln model class (e.g. RBLNLlamaForCausalLM)
    model_cls = getattr(optimum.rbln, model_cls_name, None)
    assert model_cls is not None
    # load the pre-compiled RBLN binary model (export=False loads an already
    # compiled model instead of compiling from the HF checkpoint)
    model = model_cls.from_pretrained(model_name_or_path, export=False)
    self.model = model
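get_rbln_model_info() is not shown in this RFC. As a rough sketch, assuming it simply maps the Hugging Face config's model_type to the matching optimum.rbln class name (the mapping entries below are illustrative, not exhaustive):

# Hypothetical helper; the table of supported architectures is illustrative.
from typing import Tuple

from transformers import PretrainedConfig

_RBLN_MODEL_CLS_BY_TYPE = {
    "llama": "RBLNLlamaForCausalLM",
    "gpt2": "RBLNGPT2LMHeadModel",
}


def get_rbln_model_info(config: PretrainedConfig) -> Tuple[str, str]:
    model_name = config.model_type  # e.g. "llama"
    if model_name not in _RBLN_MODEL_CLS_BY_TYPE:
        raise NotImplementedError(
            f"Model type {model_name} is not supported by the RBLN backend.")
    model_cls_name = _RBLN_MODEL_CLS_BY_TYPE[model_name]
    return model_name, model_cls_name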
Model-specific (e.g. Llama-specific) forward functions
class RBLNOptimumRBLNLlamaForCausalLM(RBLNBaseLlamaForCausalLM):
    def forward(
        self,
        input_ids: torch.Tensor,
        attn_mask: torch.Tensor,
        positions: torch.Tensor,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        is_prefill = seq_group_metadata_list[0].is_prompt
        if not is_prefill:
            input_ids, positions = self.preprocess_decode(
                input_ids, positions, seq_group_metadata_list)
        batch_indices = self.get_batch_indices(seq_group_metadata_list)
        batch_idx = batch_indices[0] if is_prefill else None
        # optimum.rbln RBLNLlamaForCausalLM.forward()
        logits = self.model.forward(input_ids=input_ids.to(torch.int64),
                                    cache_position=positions.to(torch.int32),
                                    batch_idx=batch_idx)
        if not is_prefill:
            logits = self.postprocess_decode(logits, seq_group_metadata_list)
        return logits
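preprocess_decode() and postprocess_decode() are not shown above. A minimal sketch of the idea, assuming the compiled RBLN model runs decode steps over a fixed-size batch, so per-step inputs are scattered into their assigned batch slots and the logits are gathered back into the scheduler's order. The real methods take seq_group_metadata_list and live on the model class; this standalone version passes batch_indices and max_batch_size explicitly, and the names, shapes, and padding scheme are assumptions for illustration.

# Hypothetical standalone versions; the padding scheme is an assumption.
from typing import Sequence, Tuple

import torch


def preprocess_decode(
    input_ids: torch.Tensor,        # [num_seqs] next-token ids
    positions: torch.Tensor,        # [num_seqs] cache positions
    batch_indices: Sequence[int],   # device batch slot for each sequence
    max_batch_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Scatter decode inputs into the fixed-size batch the compiled model
    # expects, leaving unused slots zero-padded.
    padded_ids = torch.zeros(max_batch_size, 1, dtype=input_ids.dtype)
    padded_pos = torch.zeros(max_batch_size, 1, dtype=positions.dtype)
    for i, slot in enumerate(batch_indices):
        padded_ids[slot, 0] = input_ids[i]
        padded_pos[slot, 0] = positions[i]
    return padded_ids, padded_pos


def postprocess_decode(
    logits: torch.Tensor,           # [max_batch_size, 1, vocab_size]
    batch_indices: Sequence[int],
) -> torch.Tensor:
    # Gather logits back into the order of the scheduled sequence groups.
    return logits[list(batch_indices)]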

References

Feedback Period.

1w

CC List.

@WoosukKwon , @rebel-shshin, @rebel-hekim, @rebel-hongseok

Any Other Things.

rebel-shshin commented 3 months ago

Hello @WoosukKwon @zhuohan123 @simon-mo,

I understand that you may have a lot on your plate, but I would greatly appreciate any feedback or thoughts you might have on this proposal when you have a moment. Your input would be invaluable in helping to refine and move this project forward. :)

github-actions[bot] commented 2 weeks ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

rebel-hongseok commented 1 week ago

Gentle reminder about this RFC - any feedback would be greatly appreciated.