vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[New Model]: How to modify WeMM to make it compatible with vllm #6281

Open wenyuzzz opened 3 months ago

wenyuzzz commented 3 months ago

The model to consider.

Thanks for the vllm team's efforts. I am preparing to optimize the inference performance of WeMM; the model link is below.

https://huggingface.co/feipengma/WeMM-Chat-2k-CN

The closest model vllm already supports.

WeMM is based on internlm2.

What's your difficulty of supporting the model you want?

The overall framework starts with modeling_wemm.py, which passes the data to modeling_internlm2.py.

However, modeling_internlm2.py replaces the basic linear layers with PLoRA and adds an image mask. The code is available in the WeMM-Chat-2k-CN repository. The code for PLoRA is as follows:

```python

class PLoRA(nn.Module):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 lora_len=0,
                 **kwargs) -> None:
        super().__init__()

        self.original_linear = nn.Linear(in_features, out_features, bias, device, dtype)

        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_len = lora_len
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        self.Plora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype)
        self.Plora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype)

        self.reset_parameters()

    def reset_parameters(self):
        # NOTE: the attributes are named Plora_A/Plora_B, so this hasattr check
        # never fires and the custom initialization below is effectively a no-op.
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x, im_mask=None):
        res = self.original_linear(x)

        if im_mask is not None:
            if torch.sum(im_mask) > 0:
                # add the low-rank update only at the masked (image-token) positions
                part_x = x[im_mask]
                res[im_mask] += self.Plora_B(
                    self.Plora_A(
                        self.lora_dropout(part_x))) * self.lora_scaling
            else:
                # no image tokens: run a zero-scaled dummy pass so the LoRA
                # weights stay in the autograd graph
                part_x = x[:, :1]
                res[:, :1] += self.Plora_B(
                    self.Plora_A(self.lora_dropout(part_x))) * 0
        return res

```
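
To make the masking behaviour concrete, here is a minimal usage sketch (my own illustration, with made-up shapes, assuming the PLoRA class above plus its torch/math imports): the low-rank branch is applied only at positions where im_mask is True (the image tokens), while every other position just gets the output of original_linear.

```python
import torch

# assumes `import math`, `import torch.nn as nn` and the PLoRA class above

layer = PLoRA(in_features=32, out_features=64, bias=False, lora_r=8, lora_alpha=8)

x = torch.randn(2, 6, 32)                      # (batch, seq_len, hidden)
im_mask = torch.zeros(2, 6, dtype=torch.bool)  # True marks image-token positions
im_mask[:, :3] = True                          # pretend the first 3 tokens are image tokens

out = layer(x, im_mask)                        # LoRA update added only where im_mask is True
print(out.shape)                               # torch.Size([2, 6, 64])
```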

In the subsequent code, PLoRA replaces the nn.Linear layers:

```python

class InternLM2MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        # self.w1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # self.w3 = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        # self.w2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        self.w1 = PLoRA(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            lora_r=_PLORA_DIM,
            lora_alpha=_PLORA_DIM,
            lora_len=576)
        self.w3 = PLoRA(
            self.hidden_size,
            self.intermediate_size,
            bias=False,
            lora_r=_PLORA_DIM,
            lora_alpha=_PLORA_DIM,
            lora_len=576)
        self.w2 = PLoRA(
            self.intermediate_size,
            self.hidden_size,
            bias=False,
            lora_r=_PLORA_DIM,
            lora_alpha=_PLORA_DIM,
            lora_len=576)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x, im_mask):
        # im_mask is threaded through every PLoRA call
        down_proj = self.w2(self.act_fn(self.w1(x, im_mask)) * self.w3(x, im_mask), im_mask)
        return down_proj

class InternLM2Attention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: InternLM2Config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.is_causal = True

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        # self.wqkv = nn.Linear(
        #     self.hidden_size,
        #     (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
        #     bias=config.bias,
        # )
        #
        # self.wo = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)

        self.wqkv = PLoRA(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.bias,
            lora_r=_PLORA_DIM,
            lora_alpha=_PLORA_DIM,
            lora_len=576)

        self.wo = PLoRA(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=config.bias,
            lora_r=_PLORA_DIM,
            lora_alpha=_PLORA_DIM,
            lora_len=576)
        self._init_rope()

    def _init_rope(self):
        if self.config.rope_scaling is None:
            self.rotary_emb = InternLM2RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.config.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            scaling_factor = self.config.rope_scaling["factor"]
            if scaling_type == "dynamic":
                self.rotary_emb = InternLM2DynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor,
                )
            elif scaling_type == "linear":
                self.rotary_emb = InternLM2LinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    base=self.config.rope_theta,
                    scaling_factor=scaling_factor,
                )
            else:
                raise ValueError("Currently we only support rotary embedding's type being 'dynamic' or 'linear'.")
        return self.rotary_emb

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        im_mask: Optional[Tuple[torch.Tensor]] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. "
                "Please make sure use `attention_mask` instead.`"
            )

        bsz, q_len, _ = hidden_states.size()
        qkv_states = self.wqkv(hidden_states, im_mask)

        qkv_states = rearrange(
            qkv_states,
            "b q (h gs d) -> b q h gs d",
            gs=2 + self.num_key_value_groups,
            d=self.head_dim,
        )

        query_states = qkv_states[..., : self.num_key_value_groups, :]
        query_states = rearrange(query_states, "b q h gs d -> b q (h gs) d")
        key_states = qkv_states[..., -2, :]
        value_states = qkv_states[..., -1, :]

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.wo(attn_output, im_mask)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

```

If I directly replace Linear with PLoRA like this, do the downstream attention and MLP implementations also need to be modified?
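
For context, my understanding (an approximation on my part, not copied from vLLM's source) is that vLLM's existing InternLM2 implementation fuses w1/w3 into a single gate_up projection, uses a SiLU-and-multiply activation, and never passes any mask through the MLP or attention forwards. The torch-only mock below sketches that structure to show where an im_mask argument would have to be threaded through; the class and parameter names are hypothetical.

```python
import torch
import torch.nn as nn

class MockInternLM2MLP(nn.Module):
    """Torch-only stand-in for the fused-MLP structure I believe vLLM uses."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # stands in for MergedColumnParallelLinear(hidden_size, [intermediate_size] * 2)
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        # stands in for RowParallelLinear(intermediate_size, hidden_size)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # no im_mask anywhere in this call chain today; with PLoRA, this forward
        # (and every caller above it) would need to accept and pass im_mask, and
        # the fused gate_up projection would have to be split back into separate
        # w1/w3 layers or wrapped as a whole.
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)   # SiLU-and-multiply equivalent
        return self.w2(nn.functional.silu(gate) * up)

mlp = MockInternLM2MLP(hidden_size=32, intermediate_size=64)
print(mlp(torch.randn(1, 4, 32)).shape)  # torch.Size([1, 4, 32])
```

With that in mind, here is my current attempt at rewriting PLoRA on top of vLLM's RowParallelLinear: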

```python

class PLoRA(nn.Module):

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 device=None,
                 dtype=None,
                 lora_r=8,
                 lora_alpha=16,
                 lora_dropout=0.05,
                 lora_len=0,
                 quant_config: Optional[QuantizationConfig] = None,
                 **kwargs) -> None:
        super().__init__()

        # self.original_linear = nn.Linear(in_features, out_features, bias, device, dtype)
        self.original_linear = RowParallelLinear(in_features,
                                                 out_features,
                                                 bias=False,
                                                 params_dtype=dtype,
                                                 quant_config=quant_config)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_len = lora_len
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.lora_scaling = self.lora_alpha / self.lora_r

        # self.Plora_A = nn.Linear(
        #     in_features, self.lora_r, bias=False, device=device, dtype=dtype)
        self.Plora_A = RowParallelLinear(in_features,
                                         self.lora_r,
                                         bias=False,
                                         params_dtype=dtype,
                                         quant_config=quant_config)
        # self.Plora_B = nn.Linear(
        #     self.lora_r, out_features, bias=False, device=device, dtype=dtype)
        self.Plora_B = RowParallelLinear(self.lora_r,
                                         out_features,
                                         bias=False,
                                         params_dtype=dtype,
                                         quant_config=quant_config)

    #     self.reset_parameters()

    # def reset_parameters(self):
    #     if hasattr(self, 'lora_A'):
    #         # initialize A the same way as the default for nn.Linear and B to zero
    #         nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
    #         nn.init.zeros_(self.lora_B.weight)

    def forward(self, x, im_mask=None):
        # RowParallelLinear returns an (output, bias) tuple, so unpack it
        res, _ = self.original_linear(x)

        if im_mask is not None:
            if torch.sum(im_mask) > 0:
                part_x = x[im_mask]
                part_feat, _ = self.Plora_A(self.lora_dropout(part_x))
                part_feat, _ = self.Plora_B(part_feat)
                res[im_mask] += part_feat * self.lora_scaling
            else:
                part_x = x[:, :1]
                part_feat, _ = self.Plora_A(self.lora_dropout(part_x))
                part_feat, _ = self.Plora_B(part_feat)
                # zero-scaled dummy pass, matching the original reference code
                res[:, :1] += part_feat * 0
        return res

```

Or is there another recommended way to make this modification? Looking forward to your reply.

github-actions[bot] commented 5 days ago

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!