Epliz opened 2 months ago
yeah this has been an issue for a while: https://github.com/ROCm/rocBLAS/issues/1238
I updated the kernel from my reproducer; it saturates memory bandwidth (unlike rocBLAS).
I see that @daineAMD replied to the other issue, so I'm mentioning it here as well in case that helps. To give context again if needed: improving rocblas_gemm_ex for cases where it corresponds to gemv operations is a very common pattern for LLM inference at batch size = 1, which gets benchmarked quite often. Given that a ~100-line kernel beats rocBLAS by 2x, I would recommend putting some effort into this. At least for the matrix shapes of popular LLMs, you could make sure it gets decent performance.
It's also pretty silly, since just using the gemv kernels in these cases should be trivial.
The suboptimality of this is of course also easily shown with rocBLAS's own tool:
```
rocblas-bench -f gemm_ex -m 1 -n 16192 -k 16192 --transposeA N --transposeB N -r s --compute_type s -i 50
transA,transB,M,N,K,alpha,lda,beta,ldb,ldc,ldd,batch_count,rocblas-Gflops,us
N,N,1,16192,16192,1,128,0,16192,128,128,1, 203.735, 2573.74
```
```
rocblas-bench -f gemv -r s -m 16192 -n 16192 --lda 16192 -i 50
transA,M,N,alpha,lda,incx,beta,incy,rocblas-Gflops,rocblas-GB/s,us
N,16192,16192,1,16192,1,0,1, 480.082, 960.224, 1092.3
```
A rocblas_hgemv would also be great, since there is an opportunity there to use dual-issue.
Hi @Epliz, thanks for bringing this up. Yes, the disparity between gemm with m == 1/n == 1 and gemv has been brought up in the past, as noted by @IMbackK. Back when it was originally raised, it wasn't clear whether the best approach would be to re-direct the gemm call to gemv (which has source kernels in rocBLAS) or to continue to gemm (which is handled within the Tensile library), since performance was somewhat of a mixed bag; and handling this on a case-by-case basis seemed infeasible.
Regardless, it's good that this has been brought up again, and I'll discuss with the team on what the best approach is. If we can get gemv to outperform gemm in every case, then the changes to redirect to gemv would be straightforward, but most of the work would lie in ensuring that gemv is faster. I'll keep you updated with any progress here.
The request for rocblas_hgemv() has also been noted, and I can discuss with the team whether or not we plan on supporting it.
Thanks, Daine
Hi @daineAMD
Thank you for the detailed comment on this matter and for:
> The request for rocblas_hgemv() has also been noted and I can discuss with the team about whether or not we plan on supporting this.
Out of curiosity: in initial experimentation with rocblas-bench I have been unable to find a configuration where gemm_ex beats gemv on gfx906, gfx908, or gfx1030. If you have notes on what these configurations might be, that would be interesting to me from a performance-optimization perspective in my own code.
Hi @daineAMD ,
Following up after a week. Do you have any example of a configuration where gemv is slower than gemm?
If not, can you please proceed with making gemm call gemv for those cases?
If the rocBLAS team cannot tackle this task, would a pull request from my side potentially be merged? I can sign whatever contribution agreement you might need.
Hi @Epliz and @IMbackK, sorry for the delay.
Looking at my past notes, it looks like the areas of most concern were where the incx parameter is large (with various exceptions), specifically gemm cases where (transA == transB == T && ldb >> 1) and (transA == transB == N && lda >> 1).
For example, the following gemm and gemv calls are essentially the same operation:
```
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 2048 -k 2048 --alpha 1 --lda 2048 --beta 0 --ldb 2048 --ldc 1
```
and
```
./rocblas-bench -f gemv -r f32_r --transposeA T -m 2048 -n 2048 --lda 2048 --incx 2048
```
Note the large `incx` here, which corresponds to the `lda` in the gemm call. You can try this out yourself, but I'm getting better performance with gemm here than gemv on MI100.
Other cases where I'm seeing gemm perform better than gemv are for small sizes, e.g.:
```
./rocblas-bench -f gemm -r f32_r --transposeA N --transposeB N -m 1 -n 1024 -k 1024 --alpha 1 --lda 1 --beta 0 --ldb 1024 --ldc 1
```
and
```
./rocblas-bench -f gemv -r f32_r --transposeA T -m 1024 -n 1024 --lda 1024 --incx 1
```
I have a ticket to investigate further to see if we can call gemv from cases where it outperforms gemm and/or see what optimizations can be done for the current gemv to make this easier; I'll be looking at this in the coming weeks.
You are free to take a look yourself and open a PR; see the contributing guide if you're interested. Merging the PR will still take some time, though, as most of the work lies in ensuring there are no performance regressions.
Thanks again, Daine
Hi @daineAMD,
Thank you for your examples; they have been useful in determining when to use gemv in my code to work around this issue and when not to. Since this issue has now been quiet for a month, and the previous issue on this topic was never resolved after two years, I think it prudent to follow up and ask whether any internal progress has been made, or a decision reached on the way forward with this performance problem.
Hi @IMbackK,
Yes, it's good to keep this topic up to date since it's been delayed for so long; thanks for your reminder. No decisions have been made on a way forward yet. Currently we are working on some potential optimizations for the gemv function, so I thought it best to hold off on making any changes until I can evaluate the performance of those gemv changes, in case it makes the decision easier.
In the meantime, I've mocked up some changes to potentially allow users to opt in to using gemv kernels from rocblas_gemm_ex() calls with m == 1 || n == 1 (and other restrictions). We'll be discussing this option once the gemv changes mentioned above are in.
Also, regarding half-precision gemv support, the following functions are in rocBLAS as of ROCm 6.0:

* `rocblas_hshgemv_batched()` / `rocblas_hshgemv_strided_batched()`
* `rocblas_hssgemv_batched()` / `rocblas_hssgemv_strided_batched()`
* `rocblas_tstgemv_batched()` / `rocblas_tstgemv_strided_batched()`
* `rocblas_tssgemv_batched()` / `rocblas_tssgemv_strided_batched()`

You can see their definitions in rocblas_functions.h. The precision prefixes represent input-compute-output types (e.g. hss is half-precision input, single-precision compute and output). It looks like they weren't added to the docs until ROCm 6.2, so they should appear in the rocBLAS documentation as of ROCm 6.2. Sorry for not mentioning them previously; they slipped my mind.
Thanks, Daine
Thanks @daineAMD for the reply. I still believe that, even if not always dispatching those cases to the gemv kernel, dispatching the configurations known to be faster with gemv would already be great. If it would be helpful, I would be happy to provide some shapes used in open-weight LLMs where inference at batch size = 1 would benefit from gemm-to-gemv lowering.
For example, for the mistral 7b model, the matrix shapes are:
Hi @daineAMD,
Thank you for your quick update.
> In the meantime, I've mocked up some changes to potentially allow users to opt-in to using gemv kernels from rocblas_gemm_ex() calls with m == 1 || n == 1 (and other restrictions). We'll be discussing this option once gemv changes mentioned prior are in.
Having this selectable via rocBLAS's API or an environment variable would work great for me as an interim solution, and presumably also for @Epliz.
> Also, regarding half-precision gemv support, the following functions are in rocBLAS as of ROCm 6.0:
>
> * `rocblas_hshgemv_batched()` / `rocblas_hshgemv_strided_batched()`
> * `rocblas_hssgemv_batched()` / `rocblas_hssgemv_strided_batched()`
> * `rocblas_tstgemv_batched()` / `rocblas_tstgemv_strided_batched()`
> * `rocblas_tssgemv_batched()` / `rocblas_tssgemv_strided_batched()`
Indeed, I was not aware of these functions due to the lack of documentation; thank you for bringing them to my attention! Thus far I have been upcasting to fp32.
Thanks for your feedback on the documentation @IMbackK. A missing space had obfuscated the changelog bullet for these functions, which I've now fixed and clarified in commit f087847adef9bf535ccf5c45834183928017bdad.
Describe the bug
As described in the title, rocblas_gemm_ex seems quite suboptimal on MI100 when m == 1, inputs/outputs are fp16, and compute is fp32. A quite naive kernel I implemented beats it.
This causes https://github.com/ROCm/pytorch/issues/1408 in PyTorch. It makes LLM inference on Mistral 7b fp16 slower than it could easily be.
To Reproduce
Here is a C++ reproducer:
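(The reproducer itself is not included in this excerpt. As a rough illustration only, and not the author's actual code, a naive kernel of the kind described, fp16 in/out with fp32 accumulation for the m == 1 case, might look like the following HIP sketch:)

```cpp
// HIP sketch (illustrative, not the original reproducer): one workgroup
// per output element, fp16 loads, fp32 accumulation, tree reduction.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

__global__ void gemv_f16_f32(const __half* __restrict__ A,  // k x n, column-major
                             const __half* __restrict__ x,  // k elements
                             __half* __restrict__ y,        // n elements
                             int k, int lda) {
    const int col = blockIdx.x;  // one block per output column
    float acc = 0.0f;
    for (int i = threadIdx.x; i < k; i += blockDim.x)
        acc += __half2float(A[i + (size_t)col * lda]) * __half2float(x[i]);

    // Block-wide tree reduction in shared memory (blockDim.x <= 256,
    // assumed to be a power of two).
    __shared__ float tmp[256];
    tmp[threadIdx.x] = acc;
    __syncthreads();
    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) tmp[threadIdx.x] += tmp[threadIdx.x + s];
        __syncthreads();
    }
    if (threadIdx.x == 0) y[col] = __float2half(tmp[0]);
}
```

Launched with a grid of n blocks, such a kernel is memory-bound on the A matrix, which is exactly the regime where the rocblas-bench numbers above show gemm_ex underperforming.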
Expected behavior
It should be at least as fast as my naive kernel. But running the above, I get:
Environment
environment.txt
Additional context
EDIT: replaced the originally included kernel with a better one.
EDIT2: replaced it with an even better kernel.