apache / tvm

Open deep learning compiler stack for CPU, GPU, and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

Revert "[KVCACHE] Improved schedule for prefill attention" #17466

Closed MasterJH5574 closed 1 month ago

MasterJH5574 commented 1 month ago

This PR reverts apache/tvm#17432, as we observed a correctness issue when num_attention_heads is 28.

The correctness issue leads to incorrect end-to-end results in LLM inference.

MasterJH5574 commented 1 month ago

@krishnaraj36 Hi! Thank you for the great contribution on the prefill attention improvement. Unfortunately, we just ran into a correctness issue caused by this PR and have therefore decided to revert it temporarily. Specifically, the prefill kernel produces incorrect results when num_qo_heads is 28 (num_kv_heads is 4, so the number of GQA groups is 7). The current unit test uses 32 for num_qo_heads, where the improved kernel works perfectly well and does not reveal the correctness issue.
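
For context, a minimal sketch of the head-configuration arithmetic involved (the grouping rule here is the standard GQA convention, group size = num_qo_heads // num_kv_heads; the snippet is illustrative only):

    # Hypothetical illustration of the head configuration described above.
    num_qo_heads = 28   # failing configuration reported here
    num_kv_heads = 4
    group_size = num_qo_heads // num_kv_heads
    print(group_size)   # 7 query heads share each KV head

    # The existing unit test uses 32 query heads, which gives a group size of 8
    # and happens not to expose the bug.
    print(32 // num_kv_heads)  # 8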

Here is how you can reproduce the issue:

  1. Replace this line with num_qo_heads=28 (a sketch of the edited line follows these steps): https://github.com/apache/tvm/blob/43f6c08f9db04adc73a17d3d99efdc6135ff0d3d/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py#L48
  2. run this test via
    python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
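
For reference, a sketch of what the edited constant could look like (the variable name is taken from the linked test line and assumed otherwise unchanged):

    # tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py, around L48
    num_qo_heads = 28  # was 32; with num_kv_heads = 4 this gives GQA group size 7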

Running the test should then show an error like the following:

~/W/tvm workspace ⇡1 *3 !2 ?2 ❯ python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
Traceback (most recent call last):
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 964, in <module>
    test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 558, in test_paged_attention_kv_cache_prefill_and_decode
    apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
  File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 468, in apply_attention
    tvm.testing.assert_allclose(
  File "/home/ruihang/Workspace/tvm/python/tvm/testing/utils.py", line 120, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0.001, atol=0.001

Mismatched elements: 4852 / 10752 (45.1%)
Max absolute difference: 0.997
Max relative difference: 86.7
 x: array([[[[0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],...
 y: array([[[[0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],
         [0.1501 , 0.9165 , 0.381  , ..., 0.7383 , 0.2344 , 0.04178],...

I think we are good to go with the improved kernels once the correctness issue is fixed. Would you mind taking a look at this issue? Thanks a lot in advance.

MasterJH5574 commented 1 month ago

BTW, one more note on the error: the kernel produces nondeterministic results; that is, if I run the test multiple times, each run produces different results.
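
For what it's worth, a minimal sketch of how one might observe the nondeterminism, assuming the command is run from the TVM repo root (the loop below is illustrative and not part of the test suite):

    # Run the failing test several times and record pass/fail; with a
    # nondeterministic kernel the outcome and mismatch counts can vary per run.
    import subprocess

    TEST = "tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py"

    returncodes = []
    for i in range(5):
        proc = subprocess.run(["python", TEST], capture_output=True, text=True)
        returncodes.append(proc.returncode)
        print(f"run {i}: {'pass' if proc.returncode == 0 else 'fail'}")

    print(returncodes)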

krishnaraj36 commented 1 month ago

@MasterJH5574: Thanks for reporting this issue. Sure, we will look into it.