vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[BUG] baichuan is not supported #985

Closed janelu9 closed 9 months ago

janelu9 commented 1 year ago

vllm==0.1.5
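
For reference, a minimal script along these lines triggers the failure ("baichuan2" here is a local Baichuan 2 checkpoint directory, as in the traceback below; the crash happens while LLM() profiles GPU memory, before any request is served):

from vllm import LLM

# vllm 0.1.5: constructing the engine already runs a profiling forward pass,
# which is where the ALiBi bias dtype mismatch below is hit.
model = LLM(model="baichuan2", trust_remote_code=True, gpu_memory_utilization=0.4)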

INFO 09-08 09:55:52 llm_engine.py:72] Initializing an LLM engine with config: model='baichuan2', tokenizer='baichuan2', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.bfloat16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
WARNING 09-08 09:55:52 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 1
----> 1 model = LLM(model = "baichuan2", trust_remote_code = True,gpu_memory_utilization = 0.4)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/entrypoints/llm.py:66, in LLM.__init__(self, model, tokenizer, tokenizer_mode, trust_remote_code, tensor_parallel_size, dtype, seed, **kwargs)
     55     kwargs["disable_log_stats"] = True
     56 engine_args = EngineArgs(
     57     model=model,
     58     tokenizer=tokenizer,
   (...)
     64     **kwargs,
     65 )
---> 66 self.llm_engine = LLMEngine.from_engine_args(engine_args)
     67 self.request_counter = Counter()

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/engine/llm_engine.py:223, in LLMEngine.from_engine_args(cls, engine_args)
    220 distributed_init_method, placement_group = initialize_cluster(
    221     parallel_config)
    222 # Create the LLM engine.
--> 223 engine = cls(*engine_configs,
    224              distributed_init_method,
    225              placement_group,
    226              log_stats=not engine_args.disable_log_stats)
    227 return engine

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/engine/llm_engine.py:105, in LLMEngine.__init__(self, model_config, cache_config, parallel_config, scheduler_config, distributed_init_method, placement_group, log_stats)
    102     self._init_workers(distributed_init_method)
    104 # Profile the memory usage and initialize the cache.
--> 105 self._init_cache()
    107 # Create the scheduler.
    108 self.scheduler = Scheduler(scheduler_config, cache_config)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/engine/llm_engine.py:185, in LLMEngine._init_cache(self)
    183 """Profiles the memory usage and initializes the KV cache."""
    184 # Get the maximum number of blocks that can be allocated on GPU and CPU.
--> 185 num_blocks = self._run_workers(
    186     "profile_num_available_blocks",
    187     get_all_outputs=True,
    188     block_size=self.cache_config.block_size,
    189     gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
    190     cpu_swap_space=self.cache_config.swap_space_bytes,
    191 )
    193 # Since we use a shared centralized controller, we take the minimum
    194 # number of blocks across all workers to make sure all the memory
    195 # operators can be applied to all workers.
    196 num_gpu_blocks = min(b[0] for b in num_blocks)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/engine/llm_engine.py:678, in LLMEngine._run_workers(self, method, get_all_outputs, *args, **kwargs)
    675     else:
    676         executor = getattr(worker, method)
--> 678     output = executor(*args, **kwargs)
    679     all_outputs.append(output)
    681 if self.parallel_config.worker_use_ray:

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/worker/worker.py:108, in Worker.profile_num_available_blocks(self, block_size, gpu_memory_utilization, cpu_swap_space)
    106 # Execute the model.
    107 num_layers = self.model_config.get_num_layers(self.parallel_config)
--> 108 self.model(
    109     input_ids=input_tokens,
    110     positions=input_positions,
    111     kv_caches=[(None, None)] * num_layers,
    112     input_metadata=input_metadata,
    113     cache_events=None,
    114 )
    116 # Calculate the number of blocks that can be allocated with the
    117 # profiled peak memory.
    118 torch.cuda.synchronize()

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py:294, in BaiChuanBaseForCausalLM.forward(self, input_ids, positions, kv_caches, input_metadata, cache_events)
    286 def forward(
    287     self,
    288     input_ids: torch.Tensor,
   (...)
    292     cache_events: Optional[List[torch.cuda.Event]],
    293 ) -> SamplerOutput:
--> 294     hidden_states = self.model(input_ids, positions, kv_caches,
    295                                input_metadata, cache_events)
    296     next_tokens = self.sampler(self.lm_head.weight, hidden_states,
    297                                input_metadata)
    298     return next_tokens

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py:262, in BaiChuanModel.forward(self, input_ids, positions, kv_caches, input_metadata, cache_events)
    260         cache_event = cache_events[i]
    261     layer = self.layers[i]
--> 262     hidden_states = layer(
    263         positions,
    264         hidden_states,
    265         kv_caches[i],
    266         input_metadata,
    267         cache_event,
    268     )
    269 hidden_states = self.norm(hidden_states)
    270 return hidden_states

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py:212, in BaiChuanDecoderLayer.forward(self, positions, hidden_states, kv_cache, input_metadata, cache_event)
    210 residual = hidden_states
    211 hidden_states = self.input_layernorm(hidden_states)
--> 212 hidden_states = self.self_attn(
    213     positions=positions,
    214     hidden_states=hidden_states,
    215     kv_cache=kv_cache,
    216     input_metadata=input_metadata,
    217     cache_event=cache_event,
    218 )
    219 hidden_states = residual + hidden_states
    221 # Fully Connected

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/models/baichuan.py:171, in BaiChuanAttention.forward(self, positions, hidden_states, kv_cache, input_metadata, cache_event)
    169 k_cache, v_cache = kv_cache
    170 if self.postion_embedding == "ALIBI":
--> 171     attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
    172                             cache_event)
    173 else:
    174     attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
    175                             input_metadata, cache_event)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/layers/attention.py:200, in PagedAttention.forward(self, query, key, value, key_cache, value_cache, input_metadata, cache_event)
    198     assert input_metadata.num_generation_tokens == 0
    199     self.set_attn_bias(input_metadata)
--> 200     self.multi_query_kv_attention(
    201         output[:num_prompt_tokens],
    202         query[:num_prompt_tokens],
    203         key[:num_prompt_tokens],
    204         value[:num_prompt_tokens],
    205         input_metadata,
    206     )
    208 # Wait until the cache op is done.
    209 if cache_event is not None:

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/vllm/model_executor/layers/attention.py:401, in PagedAttentionWithALiBi.multi_query_kv_attention(self, output, query, key, value, input_metadata)
    399 for i, prompt_len in enumerate(input_metadata.prompt_lens):
    400     end = start + prompt_len
--> 401     out = xops.memory_efficient_attention_forward(
    402         query[None, start:end],
    403         key[None, start:end],
    404         value[None, start:end],
    405         attn_bias=input_metadata.attn_bias[i],
    406         p=0.0,
    407         scale=self.scale,
    408     )
    409     # TODO(woosuk): Unnecessary copy. Optimize.
    410     output[start:end].copy_(out.squeeze(0))

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py:214, in memory_efficient_attention_forward(query, key, value, attn_bias, p, scale, op)
    201 def memory_efficient_attention_forward(
    202     query: torch.Tensor,
    203     key: torch.Tensor,
   (...)
    209     op: Optional[Type[AttentionFwOpBase]] = None,
    210 ) -> torch.Tensor:
    211     """
    212     Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`.
    213     """
--> 214     return _memory_efficient_attention_forward(
    215         Inputs(
    216             query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale
    217         ),
    218         op=op,
    219     )

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/xformers/ops/fmha/__init__.py:311, in _memory_efficient_attention_forward(inp, op)
    308 else:
    309     _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp)
--> 311 out, *_ = op.apply(inp, needs_gradient=False)
    312 return out.reshape(output_shape)

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/xformers/ops/fmha/cutlass.py:186, in FwOp.apply(cls, inp, needs_gradient)
    184     raise NotImplementedError("Unsupported attn_bias type")
    185 seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp)
--> 186 out, lse, rng_seed, rng_offset = cls.OPERATOR(
    187     query=inp.query,
    188     key=inp.key,
    189     value=inp.value,
    190     attn_bias=_get_tensor_bias(inp.attn_bias),
    191     seqstart_q=seqstart_q,
    192     seqstart_k=seqstart_k,
    193     max_seqlen_q=max_seqlen_q,
    194     dropout_p=inp.p,
    195     compute_logsumexp=needs_gradient,
    196     custom_mask_type=_custom_mask_type(inp.attn_bias),
    197     scale=inp.scale,
    198     seqlen_k=inp.attn_bias.k_seqinfo.seqlen
    199     if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
    200     else None,
    201 )
    202 ctx: Optional[Context] = None
    203 if needs_gradient:

File /mnt/e/conda-py311-cu118-torch201/lib/python3.11/site-packages/torch/_ops.py:502, in OpOverloadPacket.__call__(self, *args, **kwargs)
    497 def __call__(self, *args, **kwargs):
    498     # overloading __call__ to ensure torch.ops.foo.bar()
    499     # is still callable from JIT
    500     # We save the function ptr as the `op` attribute on
    501     # OpOverloadPacket to access it here.
--> 502     return self._op(*args, **kwargs or {})

RuntimeError: invalid dtype for bias - should match query's dtype
jeffchy commented 1 year ago

Same for Baichuan 1: RuntimeError: invalid dtype for bias - should match query's dtype

janelu9 commented 1 year ago

I added a new class, PagedAttentionBaichuan, to make sure the ALiBi bias is the same as HuggingFace's. In vllm/vllm/model_executor/layers/attention.py:

class PagedAttentionBaichuan(PagedAttentionWithALiBi):
    """PagedAttention with Baichuan's ALiBi attention bias."""

    def __init__(self,
                 num_heads: int,
                 head_size: int,
                 scale: float,
                 slopes: List[float],
                 num_kv_heads: Optional[int] = None) -> None:
        # Pass the slopes through to the parent; the (num_heads, 1, 1) buffer
        # registered below replaces the parent's 1-D alibi_slopes buffer.
        super().__init__(num_heads, head_size, scale, slopes, num_kv_heads)
        slopes = torch.tensor(slopes, dtype=torch.float32)[:, None, None]
        self.register_buffer("alibi_slopes", slopes, persistent=False)

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype) -> None:
        # Note: the caller must pass the query dtype; stock vllm 0.1.5 calls
        # set_attn_bias(input_metadata) only (see the fix discussed below).
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generate an ALiBi bias for each prompt. Each row is
        # [0, 1, ..., prompt_len - 1], matching HuggingFace's Baichuan.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                (prompt_len + 7) // 8 * 8,  # xformers needs a stride multiple of 8
                device=self.alibi_slopes.device,
                dtype=dtype,
            )[:, :, :, :prompt_len].copy_(torch.arange(prompt_len))
            bias.mul_(self.alibi_slopes)
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)

vllm/vllm/model_executor/models/baichuan.py#L150:

            self.attn = PagedAttentionBaichuan(self.num_heads, self.head_dim,
                                                scaling, alibi_slopes)

vllm/csrc/attention/attention_kernels.cu#L181:

qk += alibi_slope * token_idx;

That is, I changed the prompt's position index used for the ALiBi bias to run from 0 to prompt_len - 1 rather than from -prompt_len + 1 to 0, so it matches HuggingFace.

jeffchy commented 1 year ago

It works when I change the code that creates the attention bias:

https://github.com/vllm-project/vllm/blob/852ef5b4f5481ce526c804ea234d1de0df91f48d/vllm/model_executor/layers/attention.py#L199C1-L199C47

The basic idea is to pass the query dtype into set_attn_bias, i.e. call self.set_attn_bias(input_metadata, query.dtype) at the call site (a sketch of that change follows the code below), and change set_attn_bias as follows:

    def set_attn_bias(self, input_metadata: InputMetadata,
                      dtype: torch.dtype = torch.float32) -> None:
        if input_metadata.attn_bias:
            # Already set by a previous layer.
            return
        # Generates ALiBi mask for each prompt.
        for prompt_len in input_metadata.prompt_lens:
            bias = torch.arange(prompt_len)
            # Note(zhuohan): HF uses
            #     `bias = bias[None, :].repeat(prompt_len, 1)`
            # here. We find that both biases give the same results, but
            # the bias below more accurately follows the original ALiBi
            # paper.
            bias = bias[None, :] - bias[:, None]
            bias = bias.to(self.alibi_slopes.device)

            # When using custom attention bias, xformers requires the bias to
            # be sliced from a tensor whose length is a multiple of 8.
            padded_len = (prompt_len + 7) // 8 * 8
            bias = torch.empty(
                1,  # batch_size
                self.num_heads,
                prompt_len,
                padded_len,
                device=self.alibi_slopes.device,
                dtype=dtype,  # add this
            )[:, :, :, :prompt_len].copy_(bias)
            bias.mul_(self.alibi_slopes[:, None, None])
            attn_bias = LowerTriangularMaskWithTensorBias(bias)
            input_metadata.attn_bias.append(attn_bias)
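
For completeness, a rough sketch of the matching call-site change inside PagedAttention.forward (based on the attention.py lines shown in the traceback above; not a verbatim diff). The base set_attn_bias, and any other overrides, would likewise need to accept the extra dtype argument:

# Prompt-phase branch of PagedAttention.forward (cf. attention.py lines 198-206
# in the traceback above); only the set_attn_bias call changes.
assert input_metadata.num_generation_tokens == 0
self.set_attn_bias(input_metadata, query.dtype)  # was: self.set_attn_bias(input_metadata)
self.multi_query_kv_attention(
    output[:num_prompt_tokens],
    query[:num_prompt_tokens],
    key[:num_prompt_tokens],
    value[:num_prompt_tokens],
    input_metadata,
)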
janelu9 commented 1 year ago

> It works when I change the code that creates the attention bias ... pass the query dtype into set_attn_bias and change set_attn_bias as follows: [code quoted above]

Did you find that HuggingFace's ALiBi for Baichuan is different from vLLM's? One increases from 0, while the other increases up to 0.
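
As a side note, the two conventions differ only by a constant per query row, and softmax is invariant to adding a per-row constant, so the resulting attention probabilities are identical. A small self-contained check (not from this thread), assuming a single head with slope m:

import torch

m, n = 0.0625, 5                    # one ALiBi slope, prompt length
scores = torch.randn(n, n)          # stand-in for scaled q @ k^T
causal = torch.full((n, n), float("-inf")).triu(1)  # -inf strictly above the diagonal

pos = torch.arange(n)
bias_hf = m * pos[None, :].repeat(n, 1).float()        # each row: 0, 1, ..., n-1 ("increasing from 0")
bias_vllm = m * (pos[None, :] - pos[:, None]).float()  # row i (unmasked part): -i, ..., -1, 0 ("increasing to 0")

p_hf = torch.softmax(scores + bias_hf + causal, dim=-1)
p_vllm = torch.softmax(scores + bias_vllm + causal, dim=-1)
print(torch.allclose(p_hf, p_vllm, atol=1e-6))  # True: the rows differ by the constant m * i, which cancels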

tianchaolangzi commented 1 year ago

> It works when I change the code that creates the attention bias ... pass the query dtype into set_attn_bias and change set_attn_bias as follows: [code quoted above]

It works for me too.

jeffchy commented 1 year ago

> It works when I change the code that creates the attention bias ... [code quoted above]
>
> Did you find that HuggingFace's ALiBi for Baichuan is different from vLLM's? One increases from 0, while the other increases up to 0.

I do find that vLLM's Baichuan results are misaligned with HuggingFace's; I will check it and try your code later!

ericg108 commented 1 year ago

Same here, and I'm using vllm 0.1.6. It works fine when I use vllm 0.1.3.

mpdpey043 commented 1 year ago

I found a comment in the vLLM source code (vllm/model_executor/layers/attention.py, line 355); it says those two biases give the same results (this appears to be the Note(zhuohan) comment quoted in the code above): [screenshot of the comment]