codeplaysoftware / portBLAS

An implementation of BLAS using the SYCL open standard.
Apache License 2.0

Extended Gemm interface to support mixed precision operations #500

Closed: OuadiElfarouki closed this 4 months ago

OuadiElfarouki commented 6 months ago

Extend the Gemm operator interface to support mixed-precision operations by decoupling the element type of matrices A and B (element_in_t) from the type of the output matrix C and the scalars alpha and beta (element_out_t).

Following the notation of oneMKL's Gemm API specification (https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/blas/gemm.html#onemkl-blas-gemm), this PR allows the input type (Ta == Tb) to be set independently of the output and scalar type (Tc == Ts). As a first stage, this is enabled for Ta = Tb = sycl::half and Tc = Ts = float, so enabling half support also enables the mixed-precision (half, float) case for gemm.

Changes include:

Note: following the Gemm API expected by oneMKL, support for bfloat16-float and std::int8_t-float should be straightforward afterwards, but the additional cases where Ta == Tb == Tc while Ts (alpha and beta) is separate will require further decoupling and redesign work.

OuadiElfarouki commented 5 months ago

@s-Nick @Rbiessy @hjabird I've rebased and fixed some bugs in this PR following recent commits. Feel free to review it! Thanks.