Extend the Gemm operator interface to support mixed-precision operations by decoupling the type of matrices A and B (`element_in_t`) from the type of output matrix C and scalars alpha and beta (`element_out_t`). Following the oneMKL specification's notation for the Gemm API (https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/blas/gemm.html#onemkl-blas-gemm), this PR allows (Ta == Tb) to be set independently of (Tc == Ts). As a first stage, this is enabled for Ta = Tb = `sycl::half` and Tc = Ts = `float`, so enabling half support also enables the mixed-precision (half, float) case for gemm.

Changes include:

- Extend the Gemm operator interface to support mixed-precision operations by decoupling `element_in_t` (A and B) from `element_out_t` (C, alpha and beta).
- Update the different Gemm kernel implementations to account for the decoupled types.
- Update the CMake and kernel generation scripts to handle a pair of types instead of a single type in the gemm case.
- Update the unit tests to cover this feature in the Gemm case.
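To illustrate what decoupling the types means for the computation itself, here is a minimal host-side reference sketch. It is not portBLAS code: `reference_gemm` is a hypothetical name, it assumes column-major storage with no transposition, and because `sycl::half` needs a SYCL toolchain it uses `std::int8_t` inputs with `float` outputs as a stand-in (in, out) pair. The point it shows is that A and B share one type while C, alpha and beta share another.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference GEMM (not the portBLAS API): computes
// C = alpha * A * B + beta * C for column-major, non-transposed
// A (m x k), B (k x n) and C (m x n).
//   element_in_t  : type of A and B            (oneMKL's Ta == Tb)
//   element_out_t : type of C, alpha and beta  (oneMKL's Tc == Ts)
template <typename element_in_t, typename element_out_t>
void reference_gemm(std::size_t m, std::size_t n, std::size_t k,
                    element_out_t alpha,
                    const std::vector<element_in_t>& a, std::size_t lda,
                    const std::vector<element_in_t>& b, std::size_t ldb,
                    element_out_t beta,
                    std::vector<element_out_t>& c, std::size_t ldc) {
  // Leading dimensions must cover the logical matrix heights.
  assert(lda >= m && ldb >= k && ldc >= m);
  assert(a.size() >= lda * k && b.size() >= ldb * n && c.size() >= ldc * n);
  for (std::size_t j = 0; j < n; ++j) {
    for (std::size_t i = 0; i < m; ++i) {
      // Each input is widened to element_out_t before multiplying, so the
      // dot product is accumulated in the output precision.
      element_out_t acc = element_out_t(0);
      for (std::size_t p = 0; p < k; ++p) {
        acc += static_cast<element_out_t>(a[i + p * lda]) *
               static_cast<element_out_t>(b[p + j * ldb]);
      }
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}
```

Widening before accumulating is a choice made for this sketch: with `std::int8_t` inputs, 100 * 100 would overflow the input type but is exact in `float`. Whether the actual kernels accumulate in `element_out_t` is not stated in this PR.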
Note: following the oneMKL Gemm API, support for `bfloat16`-`float` and `std::int8_t`-`float` would be straightforward to add afterwards, but the additional cases where Ta == Tb == Tc while Ts (alpha and beta) is separate will require further decoupling and redesign work.
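The distinction in the note can be written as a compile-time check. This is a hypothetical C++17 sketch (the name `fits_in_out_pair` is not part of portBLAS or oneMKL), with standard types standing in for `sycl::half` and `bfloat16`: a oneMKL-style (Ta, Tb, Tc, Ts) combination maps onto the two-parameter (`element_in_t`, `element_out_t`) interface only when Ta == Tb and Tc == Ts, so a case where Ts differs from Tc needs a third, separate type.

```cpp
#include <type_traits>

// Hypothetical trait: true when a oneMKL-style gemm type combination
// (Ta, Tb, Tc, Ts) can be expressed with one element_in_t (for A and B)
// and one element_out_t (for C, alpha and beta).
template <typename Ta, typename Tb, typename Tc, typename Ts>
inline constexpr bool fits_in_out_pair =
    std::is_same_v<Ta, Tb> && std::is_same_v<Tc, Ts>;

// Ta == Tb == float, Tc == Ts == double: expressible as (float, double).
static_assert(fits_in_out_pair<float, float, double, double>);

// Ta == Tb == Tc == float but Ts == double: the scalar type would need
// to be decoupled from C's type, which the pair interface cannot express.
static_assert(!fits_in_out_pair<float, float, float, double>);
```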