mit-han-lab / llm-awq

[MLSys 2024 Best Paper Award] AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration
MIT License

illegal memory access when input tokens < 8 #170

Open · casper-hansen opened this issue 3 months ago

casper-hansen commented 3 months ago

Hi @ys-2020, thanks for your engineering work on the new kernels. I was recently made aware of a bug after importing the new GEMV/GEMM kernels into AutoAWQ. The issue occurs specifically in the GEMV kernel.

Conditions to trigger the bug:

Import into vLLM

The same illegal memory access occurs when importing the kernel into vLLM.


I found a workaround that seems to let us use the GEMV kernel without triggering the illegal memory access. However, I wanted to post here to get your thoughts on this fix and to ask whether you can identify what is wrong with the kernels imported from TensorRT.

Replace this:


```python
batch_size, n_tokens, _ = inputs.shape
if batch_size < 8 and n_tokens == 1:
    ...  # kernel call for this branch (omitted)
```
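A minimal sketch of the shape-based dispatch implied by the condition above, to make explicit which inputs reach the GEMV path. `select_kernel` and the `"gemv"`/`"gemm"` string labels are hypothetical stand-ins: the actual kernel calls are not shown in the issue, and the hidden size of 4096 in the example shapes is only illustrative.

```python
from typing import Tuple

# Hypothetical labels for the two kernel paths; the real module would call
# into a compiled CUDA extension instead of returning strings.
GEMV = "gemv"
GEMM = "gemm"


def select_kernel(input_shape: Tuple[int, int, int]) -> str:
    """Pick a kernel path for activations shaped (batch_size, n_tokens, hidden).

    Follows the condition in the snippet above: only single-token decode
    steps with a batch smaller than 8 are routed to GEMV. Any prefill
    (n_tokens > 1), however short, is routed to GEMM.
    """
    batch_size, n_tokens, _ = input_shape
    if batch_size < 8 and n_tokens == 1:
        return GEMV
    return GEMM


if __name__ == "__main__":
    # A 5-token prompt (fewer than 8 input tokens) takes the GEMM path.
    for shape in [(1, 5, 4096), (1, 1, 4096), (4, 1, 4096), (8, 1, 4096)]:
        print(shape, "->", select_kernel(shape))
```

With this condition, a short prompt such as 5 tokens at batch size 1 no longer reaches the GEMV kernel during prefill; GEMV is used only for one-token decode steps at batch sizes 1 through 7.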


```
Traceback (most recent call last):
  File "/workspace/AutoAWQ/examples/", line 25, in <module>
    generation_output = model.generate(
  File "/workspace/AutoAWQ/awq/models/", line 111, in generate
    return self.model.generate(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/", line 1544, in generate
    return self.greedy_search(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/", line 2404, in greedy_search
    outputs = self(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/", line 1157, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/", line 1042, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/", line 757, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/", line 666, in forward
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/", line 160, in apply_rotary_pos_emb
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
```
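Because CUDA kernel launches are asynchronous, the last frame in a traceback like the one above (`apply_rotary_pos_emb`) is not necessarily where the illegal access happened; the error can surface at a later operation. A minimal sketch, assuming the reproduction is a standalone Python script: setting the standard `CUDA_LAUNCH_BLOCKING` environment variable before CUDA is initialized makes launches synchronous, so the error should be raised at the call that actually faulted. The helper name `enable_sync_cuda_launches` is hypothetical.

```python
import os


def enable_sync_cuda_launches() -> None:
    """Force synchronous CUDA kernel launches for debugging.

    The variable is only read when CUDA is initialized, so this must run
    before the first GPU operation in the process. The safest place is the
    very top of the reproduction script, before importing torch.
    """
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


if __name__ == "__main__":
    enable_sync_cuda_launches()
    # Next: import torch and AutoAWQ, then rerun the failing generate() call.
```

Synchronous launches slow generation down considerably, so this is only meant for locating the faulting kernel, not for normal runs.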

CC @robertgshaw2-neuralmagic