riscv-non-isa / rvv-intrinsic-doc


Intrinsic Support for BF16 Extension #223

Open joshua-arch1 opened 1 year ago

joshua-arch1 commented 1 year ago

The BF16 Extension has recently been proposed with three extra Instruction Set Extensions. https://github.com/riscv/riscv-bfloat16

I'm wondering how we plan to address existing rvv intrinsics. Do we need to add new intrinsics tailored for bf16 datatype? If so, do we need to give a BF16 version for all the intrinsics with floating-point types? Maybe we can raise an issue for discussion.

joshua-arch1 commented 1 year ago

Maybe we can reuse most rvv intrinsics unless we want to generate new Zvfbfmin/Zvfbfwma instructions.

I have checked the LLVM implementation of vector BF16 widening multiply-accumulate in AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore, what is certain for RISC-V now is that we need to define a new intrinsic for vfwmaccbf16.

That is to say, rather than reusing llvm.riscv.vfwmacc.nxv8f32.nxv8f16, we need a dedicated llvm.riscv.vfwmaccbf16 intrinsic for the bf16 form.
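
At the C intrinsic level, a minimal sketch of that split might look as follows; the vbfloat16mf2_t type and the __riscv_vfwmaccbf16_vv_f32m1 spelling are assumptions for illustration only, shown next to the existing f16 intrinsic for comparison.

/* Existing FP16 widening multiply-accumulate (Zvfh): */
vfloat32m1_t __riscv_vfwmacc_vv_f32m1 (vfloat32m1_t vd, vfloat16mf2_t vs1, vfloat16mf2_t vs2, size_t vl);

/* Hypothetical BF16 counterpart for vfwmaccbf16 (Zvfbfwma); the element type
   vbfloat16mf2_t and the intrinsic name are assumptions, not settled spellings: */
vfloat32m1_t __riscv_vfwmaccbf16_vv_f32m1 (vfloat32m1_t vd, vbfloat16mf2_t vs1, vbfloat16mf2_t vs2, size_t vl);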

kito-cheng commented 1 year ago

We definitely need to introduce new types like vbfloat16m1_t and corresponding RVV C intrinsic API.

Maybe we can reuse most rvv intrinsics unless we want to generate new Zvfbfmin/Zvfbfwma instructions.

I have checked the vector BF16 widening multiply-accumulate implementation in AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore, what is certain for RISC-V now is that we need to define a new intrinsic for vfwmaccbf16.

That is to say, rather than reusing llvm.riscv.vfwmacc.nxv8f32.nxv8f16, we need a dedicated llvm.riscv.vfwmaccbf16 intrinsic for the bf16 form.

That's an LLVM implementation detail which should not be specified in the intrinsic API spec.

joshua-arch1 commented 1 year ago

We definitely need to introduce new types like vbfloat16m1_t and corresponding RVV C intrinsic API.

Maybe we can reuse most rvv intrinsics unless we want to generate new Zvfbfmin/Zvfbfwma instructions. I have checked the vector BF16 widening multiply-accumulate implementation in AArch64 SVE. It uses llvm.aarch64.sve.fmlalb.nxv4f32 for fmlalb and llvm.aarch64.sve.bfmlalb for bfmlalb. Therefore, what is certain for RISC-V now is that we need to define a new intrinsic for vfwmaccbf16. That is to say, rather than reusing llvm.riscv.vfwmacc.nxv8f32.nxv8f16, we need a dedicated llvm.riscv.vfwmaccbf16 intrinsic for the bf16 form.

That's an LLVM implementation detail which should not be specified in the intrinsic API spec.

But I don't think we need to add a bfloat16 variant of every rvv floating-point intrinsic if we define functions to convert bf16 to fp32/fp16. Z(v)fbfmin has the corresponding instructions.
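
For illustration, a conversion-based routine might look roughly like the sketch below. Only __riscv_vfmul_vf_f32m2 is an existing intrinsic; vbfloat16m1_t and the two conversion names are hypothetical placeholders patterned on the Zvfbfmin instructions.

#include <riscv_vector.h>

/* Scale a BF16 vector by a float scalar without adding BF16 variants of the
   arithmetic intrinsics: widen to f32, compute there, narrow back to bf16.
   vbfloat16m1_t and both conversion names below are assumptions. */
vbfloat16m1_t scale_bf16 (vbfloat16m1_t v, float s, size_t vl) {
  vfloat32m2_t w = __riscv_vfwcvtbf16_f_f_v_f32m2 (v, vl);  /* bf16 -> f32 (assumed name) */
  w = __riscv_vfmul_vf_f32m2 (w, s, vl);                    /* existing f32 intrinsic */
  return __riscv_vfncvtbf16_f_f_w_bf16m1 (w, vl);           /* f32 -> bf16 (assumed name) */
}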

kito-cheng commented 1 year ago

But I don't think we need to add a bfloat16 variant of every rvv floating-point intrinsic if we define functions to convert bf16 to fp32/fp16. Z(v)fbfmin has the corresponding instructions.

At least we should define intrinsics for the convert instructions, define __riscv_vfwmaccbf16[vv|vf]_bf16* for Zvfbfwma, and also add some type utility functions like reinterpret.
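
As a sketch of the type-utility side, the reinterpret helpers might be spelled along these lines; the names and the vbfloat16m1_t type are assumptions patterned on the existing f16 reinterpret functions.

/* Hypothetical reinterpret casts between 16-bit integer vectors and a proposed
   BF16 vector type; names mirror the existing __riscv_vreinterpret_* set. */
vbfloat16m1_t __riscv_vreinterpret_v_i16m1_bf16m1 (vint16m1_t src);
vbfloat16m1_t __riscv_vreinterpret_v_u16m1_bf16m1 (vuint16m1_t src);
vint16m1_t __riscv_vreinterpret_v_bf16m1_i16m1 (vbfloat16m1_t src);
vuint16m1_t __riscv_vreinterpret_v_bf16m1_u16m1 (vbfloat16m1_t src);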

joshua-arch1 commented 1 year ago

So I'll add bf16-format intrinsics for the following float16-typed functions.

Reinterpret Cast Conversion Functions:
vfloat16mf4_t __riscv_vreinterpret_v_i16mf4_f16mf4 (vint16mf4_t src);
vfloat16mf2_t __riscv_vreinterpret_v_i16mf2_f16mf2 (vint16mf2_t src);
vfloat16m1_t __riscv_vreinterpret_v_i16m1_f16m1 (vint16m1_t src);
vfloat16m2_t __riscv_vreinterpret_v_i16m2_f16m2 (vint16m2_t src);
vfloat16m4_t __riscv_vreinterpret_v_i16m4_f16m4 (vint16m4_t src);
vfloat16m8_t __riscv_vreinterpret_v_i16m8_f16m8 (vint16m8_t src);
vfloat16mf4_t __riscv_vreinterpret_v_u16mf4_f16mf4 (vuint16mf4_t src);
vfloat16mf2_t __riscv_vreinterpret_v_u16mf2_f16mf2 (vuint16mf2_t src);
vfloat16m1_t __riscv_vreinterpret_v_u16m1_f16m1 (vuint16m1_t src);
vfloat16m2_t __riscv_vreinterpret_v_u16m2_f16m2 (vuint16m2_t src);
vfloat16m4_t __riscv_vreinterpret_v_u16m4_f16m4 (vuint16m4_t src);
vfloat16m8_t __riscv_vreinterpret_v_u16m8_f16m8 (vuint16m8_t src);
vint16mf4_t __riscv_vreinterpret_v_f16mf4_i16mf4 (vfloat16mf4_t src);
vint16mf2_t __riscv_vreinterpret_v_f16mf2_i16mf2 (vfloat16mf2_t src);
vint16m1_t __riscv_vreinterpret_v_f16m1_i16m1 (vfloat16m1_t src);
vint16m2_t __riscv_vreinterpret_v_f16m2_i16m2 (vfloat16m2_t src);
vint16m4_t __riscv_vreinterpret_v_f16m4_i16m4 (vfloat16m4_t src);
vint16m8_t __riscv_vreinterpret_v_f16m8_i16m8 (vfloat16m8_t src);
vuint16mf4_t __riscv_vreinterpret_v_f16mf4_u16mf4 (vfloat16mf4_t src);
vuint16mf2_t __riscv_vreinterpret_v_f16mf2_u16mf2 (vfloat16mf2_t src);
vuint16m1_t __riscv_vreinterpret_v_f16m1_u16m1 (vfloat16m1_t src);
vuint16m2_t __riscv_vreinterpret_v_f16m2_u16m2 (vfloat16m2_t src);
vuint16m4_t __riscv_vreinterpret_v_f16m4_u16m4 (vfloat16m4_t src);
vuint16m8_t __riscv_vreinterpret_v_f16m8_u16m8 (vfloat16m8_t src);

Single-Width Floating-Point/Integer Type-Convert Functions:
vint16mf4_t __riscv_vfcvt_x_f_v_i16mf4 (vfloat16mf4_t src, size_t vl);
vint16mf4_t __riscv_vfcvt_rtz_x_f_v_i16mf4 (vfloat16mf4_t src, size_t vl);
vint16mf2_t __riscv_vfcvt_x_f_v_i16mf2 (vfloat16mf2_t src, size_t vl);
vint16mf2_t __riscv_vfcvt_rtz_x_f_v_i16mf2 (vfloat16mf2_t src, size_t vl);
vint16m1_t __riscv_vfcvt_x_f_v_i16m1 (vfloat16m1_t src, size_t vl);
vint16m1_t __riscv_vfcvt_rtz_x_f_v_i16m1 (vfloat16m1_t src, size_t vl);
vint16m2_t __riscv_vfcvt_x_f_v_i16m2 (vfloat16m2_t src, size_t vl);
vint16m2_t __riscv_vfcvt_rtz_x_f_v_i16m2 (vfloat16m2_t src, size_t vl);
vint16m4_t __riscv_vfcvt_x_f_v_i16m4 (vfloat16m4_t src, size_t vl);
vint16m4_t __riscv_vfcvt_rtz_x_f_v_i16m4 (vfloat16m4_t src, size_t vl);
vint16m8_t __riscv_vfcvt_x_f_v_i16m8 (vfloat16m8_t src, size_t vl);
vint16m8_t __riscv_vfcvt_rtz_x_f_v_i16m8 (vfloat16m8_t src, size_t vl);
vuint16mf4_t __riscv_vfcvt_xu_f_v_u16mf4 (vfloat16mf4_t src, size_t vl);
vuint16mf4_t __riscv_vfcvt_rtz_xu_f_v_u16mf4 (vfloat16mf4_t src, size_t vl);
vuint16mf2_t __riscv_vfcvt_xu_f_v_u16mf2 (vfloat16mf2_t src, size_t vl);
vuint16mf2_t __riscv_vfcvt_rtz_xu_f_v_u16mf2 (vfloat16mf2_t src, size_t vl);
vuint16m1_t __riscv_vfcvt_xu_f_v_u16m1 (vfloat16m1_t src, size_t vl);
vuint16m1_t __riscv_vfcvt_rtz_xu_f_v_u16m1 (vfloat16m1_t src, size_t vl);
vuint16m2_t __riscv_vfcvt_xu_f_v_u16m2 (vfloat16m2_t src, size_t vl);
vuint16m2_t __riscv_vfcvt_rtz_xu_f_v_u16m2 (vfloat16m2_t src, size_t vl);
vuint16m4_t __riscv_vfcvt_xu_f_v_u16m4 (vfloat16m4_t src, size_t vl);
vuint16m4_t __riscv_vfcvt_rtz_xu_f_v_u16m4 (vfloat16m4_t src, size_t vl);
vuint16m8_t __riscv_vfcvt_xu_f_v_u16m8 (vfloat16m8_t src, size_t vl);
vuint16m8_t __riscv_vfcvt_rtz_xu_f_v_u16m8 (vfloat16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfcvt_f_x_v_f16mf4 (vint16mf4_t src, size_t vl);
vfloat16mf2_t __riscv_vfcvt_f_x_v_f16mf2 (vint16mf2_t src, size_t vl);
vfloat16m1_t __riscv_vfcvt_f_x_v_f16m1 (vint16m1_t src, size_t vl);
vfloat16m2_t __riscv_vfcvt_f_x_v_f16m2 (vint16m2_t src, size_t vl);
vfloat16m4_t __riscv_vfcvt_f_x_v_f16m4 (vint16m4_t src, size_t vl);
vfloat16m8_t __riscv_vfcvt_f_x_v_f16m8 (vint16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfcvt_f_xu_v_f16mf4 (vuint16mf4_t src, size_t vl);
vfloat16mf2_t __riscv_vfcvt_f_xu_v_f16mf2 (vuint16mf2_t src, size_t vl);
vfloat16m1_t __riscv_vfcvt_f_xu_v_f16m1 (vuint16m1_t src, size_t vl);
vfloat16m2_t __riscv_vfcvt_f_xu_v_f16m2 (vuint16m2_t src, size_t vl);
vfloat16m4_t __riscv_vfcvt_f_xu_v_f16m4 (vuint16m4_t src, size_t vl);
vfloat16m8_t __riscv_vfcvt_f_xu_v_f16m8 (vuint16m8_t src, size_t vl);

Widening Floating-Point/Integer Type-Convert Functions:
vfloat16mf4_t __riscv_vfwcvt_f_x_v_f16mf4 (vint8mf8_t src, size_t vl);
vfloat16mf2_t __riscv_vfwcvt_f_x_v_f16mf2 (vint8mf4_t src, size_t vl);
vfloat16m1_t __riscv_vfwcvt_f_x_v_f16m1 (vint8mf2_t src, size_t vl);
vfloat16m2_t __riscv_vfwcvt_f_x_v_f16m2 (vint8m1_t src, size_t vl);
vfloat16m4_t __riscv_vfwcvt_f_x_v_f16m4 (vint8m2_t src, size_t vl);
vfloat16m8_t __riscv_vfwcvt_f_x_v_f16m8 (vint8m4_t src, size_t vl);
vfloat16mf4_t __riscv_vfwcvt_f_xu_v_f16mf4 (vuint8mf8_t src, size_t vl);
vfloat16mf2_t __riscv_vfwcvt_f_xu_v_f16mf2 (vuint8mf4_t src, size_t vl);
vfloat16m1_t __riscv_vfwcvt_f_xu_v_f16m1 (vuint8mf2_t src, size_t vl);
vfloat16m2_t __riscv_vfwcvt_f_xu_v_f16m2 (vuint8m1_t src, size_t vl);
vfloat16m4_t __riscv_vfwcvt_f_xu_v_f16m4 (vuint8m2_t src, size_t vl);
vfloat16m8_t __riscv_vfwcvt_f_xu_v_f16m8 (vuint8m4_t src, size_t vl);
vint32mf2_t __riscv_vfwcvt_x_f_v_i32mf2 (vfloat16mf4_t src, size_t vl);
vint32mf2_t __riscv_vfwcvt_rtz_x_f_v_i32mf2 (vfloat16mf4_t src, size_t vl);
vint32m1_t __riscv_vfwcvt_x_f_v_i32m1 (vfloat16mf2_t src, size_t vl);
vint32m1_t __riscv_vfwcvt_rtz_x_f_v_i32m1 (vfloat16mf2_t src, size_t vl);
vint32m2_t __riscv_vfwcvt_x_f_v_i32m2 (vfloat16m1_t src, size_t vl);
vint32m2_t __riscv_vfwcvt_rtz_x_f_v_i32m2 (vfloat16m1_t src, size_t vl);
vint32m4_t __riscv_vfwcvt_x_f_v_i32m4 (vfloat16m2_t src, size_t vl);
vint32m4_t __riscv_vfwcvt_rtz_x_f_v_i32m4 (vfloat16m2_t src, size_t vl);
vint32m8_t __riscv_vfwcvt_x_f_v_i32m8 (vfloat16m4_t src, size_t vl);
vint32m8_t __riscv_vfwcvt_rtz_x_f_v_i32m8 (vfloat16m4_t src, size_t vl);
vuint32mf2_t __riscv_vfwcvt_xu_f_v_u32mf2 (vfloat16mf4_t src, size_t vl);
vuint32mf2_t __riscv_vfwcvt_rtz_xu_f_v_u32mf2 (vfloat16mf4_t src, size_t vl);
vuint32m1_t __riscv_vfwcvt_xu_f_v_u32m1 (vfloat16mf2_t src, size_t vl);
vuint32m1_t __riscv_vfwcvt_rtz_xu_f_v_u32m1 (vfloat16mf2_t src, size_t vl);
vuint32m2_t __riscv_vfwcvt_xu_f_v_u32m2 (vfloat16m1_t src, size_t vl);
vuint32m2_t __riscv_vfwcvt_rtz_xu_f_v_u32m2 (vfloat16m1_t src, size_t vl);
vuint32m4_t __riscv_vfwcvt_xu_f_v_u32m4 (vfloat16m2_t src, size_t vl);
vuint32m4_t __riscv_vfwcvt_rtz_xu_f_v_u32m4 (vfloat16m2_t src, size_t vl);
vuint32m8_t __riscv_vfwcvt_xu_f_v_u32m8 (vfloat16m4_t src, size_t vl);
vuint32m8_t __riscv_vfwcvt_rtz_xu_f_v_u32m8 (vfloat16m4_t src, size_t vl);

Narrowing Floating-Point/Integer Type-Convert Functions:
vint8mf8_t __riscv_vfncvt_x_f_w_i8mf8 (vfloat16mf4_t src, size_t vl);
vint8mf8_t __riscv_vfncvt_rtz_x_f_w_i8mf8 (vfloat16mf4_t src, size_t vl);
vint8mf4_t __riscv_vfncvt_x_f_w_i8mf4 (vfloat16mf2_t src, size_t vl);
vint8mf4_t __riscv_vfncvt_rtz_x_f_w_i8mf4 (vfloat16mf2_t src, size_t vl);
vint8mf2_t __riscv_vfncvt_x_f_w_i8mf2 (vfloat16m1_t src, size_t vl);
vint8mf2_t __riscv_vfncvt_rtz_x_f_w_i8mf2 (vfloat16m1_t src, size_t vl);
vint8m1_t __riscv_vfncvt_x_f_w_i8m1 (vfloat16m2_t src, size_t vl);
vint8m1_t __riscv_vfncvt_rtz_x_f_w_i8m1 (vfloat16m2_t src, size_t vl);
vint8m2_t __riscv_vfncvt_x_f_w_i8m2 (vfloat16m4_t src, size_t vl);
vint8m2_t __riscv_vfncvt_rtz_x_f_w_i8m2 (vfloat16m4_t src, size_t vl);
vint8m4_t __riscv_vfncvt_x_f_w_i8m4 (vfloat16m8_t src, size_t vl);
vint8m4_t __riscv_vfncvt_rtz_x_f_w_i8m4 (vfloat16m8_t src, size_t vl);
vuint8mf8_t __riscv_vfncvt_xu_f_w_u8mf8 (vfloat16mf4_t src, size_t vl);
vuint8mf8_t __riscv_vfncvt_rtz_xu_f_w_u8mf8 (vfloat16mf4_t src, size_t vl);
vuint8mf4_t __riscv_vfncvt_xu_f_w_u8mf4 (vfloat16mf2_t src, size_t vl);
vuint8mf4_t __riscv_vfncvt_rtz_xu_f_w_u8mf4 (vfloat16mf2_t src, size_t vl);
vuint8mf2_t __riscv_vfncvt_xu_f_w_u8mf2 (vfloat16m1_t src, size_t vl);
vuint8mf2_t __riscv_vfncvt_rtz_xu_f_w_u8mf2 (vfloat16m1_t src, size_t vl);
vuint8m1_t __riscv_vfncvt_xu_f_w_u8m1 (vfloat16m2_t src, size_t vl);
vuint8m1_t __riscv_vfncvt_rtz_xu_f_w_u8m1 (vfloat16m2_t src, size_t vl);
vuint8m2_t __riscv_vfncvt_xu_f_w_u8m2 (vfloat16m4_t src, size_t vl);
vuint8m2_t __riscv_vfncvt_rtz_xu_f_w_u8m2 (vfloat16m4_t src, size_t vl);
vuint8m4_t __riscv_vfncvt_xu_f_w_u8m4 (vfloat16m8_t src, size_t vl);
vuint8m4_t __riscv_vfncvt_rtz_xu_f_w_u8m4 (vfloat16m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfncvt_f_x_w_f16mf4 (vint32mf2_t src, size_t vl);
vfloat16mf2_t __riscv_vfncvt_f_x_w_f16mf2 (vint32m1_t src, size_t vl);
vfloat16m1_t __riscv_vfncvt_f_x_w_f16m1 (vint32m2_t src, size_t vl);
vfloat16m2_t __riscv_vfncvt_f_x_w_f16m2 (vint32m4_t src, size_t vl);
vfloat16m4_t __riscv_vfncvt_f_x_w_f16m4 (vint32m8_t src, size_t vl);
vfloat16mf4_t __riscv_vfncvt_f_xu_w_f16mf4 (vuint32mf2_t src, size_t vl);
vfloat16mf2_t __riscv_vfncvt_f_xu_w_f16mf2 (vuint32m1_t src, size_t vl);
vfloat16m1_t __riscv_vfncvt_f_xu_w_f16m1 (vuint32m2_t src, size_t vl);
vfloat16m2_t __riscv_vfncvt_f_xu_w_f16m2 (vuint32m4_t src, size_t vl);
vfloat16m4_t __riscv_vfncvt_f_xu_w_f16m4 (vuint32m8_t src, size_t vl);

As for __riscv_vfwcvt_f_f_v_f32 and __riscv_vfncvt_f_f_w_f16, I prefer to use a new naming format following vfwcvtbf16.f.f.v and vfncvtbf16.f.f.w in the new Zvfbfmin extension, so I didn't include them here.
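
Following that naming, the new conversions for one SEW/LMUL pairing might be spelled roughly as below; the intrinsic names and the vbfloat16m1_t type are illustrative assumptions, not settled spellings.

/* Hypothetical intrinsics mirroring vfwcvtbf16.f.f.v and vfncvtbf16.f.f.w: */
vfloat32m2_t __riscv_vfwcvtbf16_f_f_v_f32m2 (vbfloat16m1_t src, size_t vl);
vbfloat16m1_t __riscv_vfncvtbf16_f_f_w_bf16m1 (vfloat32m2_t src, size_t vl);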

fuhle044 commented 1 year ago

But I don't think we need to add a bfloat16 variant of every rvv floating-point intrinsic if we define functions to convert bf16 to fp32/fp16. Z(v)fbfmin has the corresponding instructions.

At least we should define intrinsics for the convert instructions, define __riscv_vfwmaccbf16[vv|vf]_bf16* for Zvfbfwma, and also add some type utility functions like reinterpret.

Would it make sense to introduce BF16 load/store intrinsics that do a 16-bit integer load followed by a reinterpret cast? This would simplify the user interface considerably.
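
A rough sketch of the two options: __riscv_vle16_v_u16m1 is an existing intrinsic, while the reinterpret cast and the dedicated BF16 load shown here use hypothetical names for the proposal under discussion.

#include <riscv_vector.h>
#include <stdint.h>

/* Option A: compose today's 16-bit integer load with a (hypothetical)
   reinterpret cast to a proposed BF16 vector type. */
vbfloat16m1_t load_bf16_via_reinterpret (const uint16_t *p, size_t vl) {
  vuint16m1_t bits = __riscv_vle16_v_u16m1 (p, vl);   /* existing intrinsic */
  return __riscv_vreinterpret_v_u16m1_bf16m1 (bits);  /* assumed name */
}

/* Option B: a dedicated BF16 load that hides the cast (assumed name):
   vbfloat16m1_t __riscv_vle16_v_bf16m1 (const __bf16 *base, size_t vl); */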

joshua-arch1 commented 1 year ago

Thank you for your suggestion. I'll add load/store intrinsics in my PR.