Closed: MasterJH5574 closed this 1 month ago.
@krishnaraj36 Hi! Thank you for the great contribution on the prefill attention improvement. Unfortunately, we just ran into a correctness issue caused by this PR, so we have decided to temporarily revert it first. Specifically, the prefill kernel produces incorrect results when `num_qo_heads` is 28 and `num_kv_heads` is 4, so the number of GQA groups is 7. The current unit test uses 32 as `num_qo_heads`, where the improved kernel works perfectly well and therefore doesn't reveal the correctness issue.
Here is how you can reproduce the issue:
- replace this line with `num_qo_heads=28`: https://github.com/apache/tvm/blob/43f6c08f9db04adc73a17d3d99efdc6135ff0d3d/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py#L48
- run the test via `python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py`
It should then show an error like:
~/W/tvm workspace ⇡1 *3 !2 ?2 ❯ python tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
Traceback (most recent call last):
File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 964, in <module>
test_paged_attention_kv_cache_prefill_and_decode(cache_and_config)
File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 558, in test_paged_attention_kv_cache_prefill_and_decode
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v)
File "/home/ruihang/Workspace/tvm/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py", line 468, in apply_attention
tvm.testing.assert_allclose(
File "/home/ruihang/Workspace/tvm/python/tvm/testing/utils.py", line 120, in assert_allclose
np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)
File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 1504, in assert_allclose
assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/home/ruihang/Workspace/miniconda3/envs/python311/lib/python3.11/site-packages/numpy/testing/_private/utils.py", line 797, in assert_array_compare
raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=0.001, atol=0.001
Mismatched elements: 4852 / 10752 (45.1%)
Max absolute difference: 0.997
Max relative difference: 86.7
x: array([[[[0.1501 , 0.9165 , 0.381 , ..., 0.7383 , 0.2344 , 0.04178],
[0.1501 , 0.9165 , 0.381 , ..., 0.7383 , 0.2344 , 0.04178],
[0. , 0. , 0. , ..., 0. , 0. , 0. ],...
y: array([[[[0.1501 , 0.9165 , 0.381 , ..., 0.7383 , 0.2344 , 0.04178],
[0.1501 , 0.9165 , 0.381 , ..., 0.7383 , 0.2344 , 0.04178],
[0.1501 , 0.9165 , 0.381 , ..., 0.7383 , 0.2344 , 0.04178],...
I think we are good to go with the improved kernels once the correctness issue is fixed. Would you mind taking a look at this issue? Thanks a lot in advance.
BTW, one more piece of information about the error: the kernel produces nondeterministic results. That is to say, if I run the test multiple times, the kernel produces different results each time.
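As a rough way to observe the nondeterminism, one could rerun the test script a few times and compare the reported mismatch counts. Here is a hypothetical sketch (the script path is the one above; everything else is illustrative):

```python
import subprocess

# Hypothetical helper: rerun the failing test a few times and compare the
# "Mismatched elements" line from each run. For a deterministic kernel the
# reported mismatch count (and the failure itself) should be identical.
script = "tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py"
reports = []
for i in range(3):
    proc = subprocess.run(["python", script], capture_output=True, text=True)
    mismatch = [line for line in (proc.stdout + proc.stderr).splitlines()
                if "Mismatched elements" in line]
    reports.append(mismatch)
    print(f"run {i}: {mismatch}")

print("identical across runs:", all(r == reports[0] for r in reports))
```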
@MasterJH5574: Thanks for reporting this issue. Sure, we will look into it.
This PR reverts apache/tvm#17432, as we observed a correctness issue when `num_attention_heads` is 28. The correctness issue leads to incorrect end-to-end results in LLM inference.