Recently @mariecwhite has been adding s8s4s32 code paths to the mmt4d ukernel, including optimized code paths for arm64 but not for x86-64. This Issue is about adding the x86-64 pieces.
Explanation of "mmt4d": "matrix-times-matrix-transposed on 4D tensors" == our matrix-multiplication ukernel.
Explanation of "s8s4s32": this is the type triple describing the mmt4d op. Here s8 is the LHS element type = signed int8, s4 is the RHS element type = signed int4, s32 is the accumulator (output) element type.
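Since s4 values are stored two per byte, an s8s4s32 kernel has to unpack and sign-extend each 4-bit value before multiplying. The following scalar sketch illustrates that arithmetic only. The helper names are hypothetical, and the nibble order (even element in the low nibble, odd element in the high nibble) is an assumption for illustration, not necessarily the layout the actual ukernel uses.

```c
#include <assert.h>
#include <stdint.h>

// Sign-extend a 4-bit two's-complement value (low 4 bits of `nibble`)
// to the range [-8, 7].
static inline int8_t sign_extend_s4(uint8_t nibble) {
  return (int8_t)(((nibble & 0x0F) ^ 0x08) - 8);
}

// Dot product of k signed 8-bit LHS values with k signed 4-bit RHS values,
// accumulated into a signed 32-bit result.
// ASSUMPTION: RHS element 2*i sits in the low nibble of byte i and
// element 2*i+1 in the high nibble.
static int32_t dot_s8s4s32(const int8_t* lhs, const uint8_t* rhs_packed,
                           int k) {
  assert(k % 2 == 0);  // two s4 values per byte
  int32_t acc = 0;
  for (int i = 0; i < k / 2; ++i) {
    int8_t rhs_even = sign_extend_s4(rhs_packed[i] & 0x0F);
    int8_t rhs_odd = sign_extend_s4(rhs_packed[i] >> 4);
    acc += (int32_t)lhs[2 * i] * rhs_even;
    acc += (int32_t)lhs[2 * i + 1] * rhs_odd;
  }
  return acc;
}
```

A SIMD kernel would do the same split on whole vectors at once, masking out the even and odd nibbles of every byte before feeding them to the multiply-accumulate instructions.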
Get familiar with the code:
- The mmt4d ukernel: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/mmt4d.c#L115-L141
- For s8s4s32, the x86-64 implementation just returns NULL here, as there is no case for it, so it hits this default: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_entry_point.c#L414-L415
- Back on x86, there is a closely related existing kernel for s8s8s32. It is almost the same; the only difference is that the RHS is s8 instead of s4, so it doesn't need to do the extra work of unpacking two 4-bit values from each byte.
- The _avx512 files in this directory have corresponding AVX512 cases.
- s16u4s32, where the u stands for unsigned: https://github.com/openxla/iree/blob/9d6d99f04c4a49dbc20fbd0656b829a8e000e260/runtime/src/iree/builtins/ukernel/arch/x86_64/mmt4d_x86_64_avx512_vnni.c#L189-L295

Explanation of the tile sizes:
- Tile sizes are given in M0xN0xK0 convention. For example, 8x32x4 means a tile of 8 along the M dimension, 32 along the N dimension, and 4 along the K dimension.
- Each ukernel can dictate its preferred tile size. You are free to choose what tile size you want here, but I would suggest that you start with what the existing s8s8s32 ukernel does on x86, and multiply its K0 tile size by 2 to account for the fact that 4-bit values are 2x smaller and you want to mask odd/even lanes and still have enough to feed your arithmetic instructions. So for instance on AVX2, the existing tile is Mx8x2, so you could start from Mx8x4 in your case. You will need to implement M values 1, 2, 4, 8 (M=8 is needed for the general case, other M values are needed for narrow problems such as vector-times-matrix).
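To make the M0xN0xK0 tile convention concrete, here is a scalar reference sketch of a single s8s4s32 tile for the suggested Mx8x4 shape with M = 2. Everything in it is an assumption for illustration: the function and constant names are hypothetical, the LHS tile is taken as row-major M0xK0, the RHS tile as row-major N0xK0 (the "transposed" operand) with two s4 values per byte and the even element in the low nibble. It is not the ukernel's actual API or data layout.

```c
#include <assert.h>
#include <stdint.h>

// Hypothetical tile shape for illustration: Mx8x4 with M = 2.
enum { M0 = 2, N0 = 8, K0 = 4 };

// Read the signed 4-bit element at `index` from a packed buffer.
// ASSUMPTION: even indices live in the low nibble, odd in the high nibble.
static int32_t s4_at(const uint8_t* packed, int index) {
  uint8_t byte = packed[index / 2];
  uint8_t nibble = (index % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
  return (int32_t)((nibble ^ 0x08) - 8);  // sign-extend to [-8, 7]
}

// acc[m][n] += sum over k of lhs[m][k] * rhs[n][k]
// lhs: M0*K0 int8 values, rhs_packed: N0*K0/2 bytes, acc: M0*N0 int32 values.
static void tile_s8s4s32(const int8_t* lhs, const uint8_t* rhs_packed,
                         int32_t* acc) {
  assert(K0 % 2 == 0);  // each RHS row must fill whole bytes
  for (int m = 0; m < M0; ++m) {
    for (int n = 0; n < N0; ++n) {
      int32_t sum = acc[m * N0 + n];
      for (int k = 0; k < K0; ++k) {
        sum += (int32_t)lhs[m * K0 + k] * s4_at(rhs_packed, n * K0 + k);
      }
      acc[m * N0 + n] = sum;
    }
  }
}
```

With K0 = 4, each RHS row occupies 2 bytes, the same as an s8 row with K0 = 2, which is the reason for doubling K0 relative to the s8s8s32 tile.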
How to run tests and micro benchmarks: