ROCm / triton

Development repository for the Triton language and compiler
MIT License

added bf16 support for perf-kernels #513

Closed · jtang10 closed this 4 months ago

jtang10 commented 4 months ago

06-fused-attention-fwd-transV.py:

  1. Added bfloat16. test_op_fwd has two test cases failing, each of which has two elements in the final output just outside the error threshold.
  2. fp8 can be skipped if it is not available in the docker image (see the sketch after this list for one way to parametrize and skip).
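
A minimal sketch of the dtype coverage described above, assuming a pytest suite shaped like the repo's test_op_fwd. The plain torch attention below stands in for the Triton kernel output; the shapes, tolerances, test names, and the fp8 availability probe are illustrative assumptions, not the PR's actual diff.

```python
import pytest
import torch

requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a GPU")

def fp8_available():
    # assumption: treat fp8 as usable only if this torch build exposes it
    return hasattr(torch, "float8_e5m2")

@requires_gpu
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_op_fwd_sketch(dtype):
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 2, 64, 32, device="cuda", dtype=dtype)
               for _ in range(3))
    sm_scale = 0.5

    # fp32 reference
    ref = torch.softmax((q.float() @ k.float().transpose(-1, -2)) * sm_scale,
                        dim=-1) @ v.float()

    # low-precision path (stand-in for the fused attention kernel)
    out = (torch.softmax((q @ k.transpose(-1, -2)) * sm_scale, dim=-1) @ v).float()

    # bf16 keeps 8 mantissa bits vs fp16's 10, so it needs a wider threshold;
    # the two failing cases mentioned above are elements just past this line
    atol = 2e-2 if dtype == torch.bfloat16 else 1e-2
    torch.testing.assert_close(out, ref, atol=atol, rtol=0)

@pytest.mark.skipif(not fp8_available(), reason="fp8 dtypes not in this torch build")
def test_op_fwd_fp8_sketch():
    ...  # fp8 variant, skipped when the docker image lacks fp8 support
```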

06-fused-attention-transV.py:

  1. Enabled bfloat16 for both forward and backward, though the backward pass currently fails its test cases. Since we don't really run the backward pass right now, I left it aside (one way to keep that coverage without breaking the suite is sketched below).
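
A sketch of keeping the failing bf16 backward case visible without breaking the suite, assuming pytest: mark it xfail instead of deleting the coverage. The test name, shapes, tolerances, and the reason string are illustrative, not the PR's code; the torch math stands in for the Triton kernel.

```python
import pytest
import torch

requires_gpu = pytest.mark.skipif(not torch.cuda.is_available(), reason="needs a GPU")

@requires_gpu
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float16,
        # assumption: mirror the known-failing bf16 backward with xfail
        pytest.param(
            torch.bfloat16,
            marks=pytest.mark.xfail(reason="bf16 backward currently mismatches"),
        ),
    ],
)
def test_op_bwd_sketch(dtype):
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 2, 64, 32, device="cuda", dtype=dtype,
                           requires_grad=True) for _ in range(3))
    out = torch.softmax((q @ k.transpose(-1, -2)) * 0.5, dim=-1) @ v
    out.sum().backward()

    # fp32 reference gradients
    q32, k32, v32 = (t.detach().float().requires_grad_() for t in (q, k, v))
    ref = torch.softmax((q32 @ k32.transpose(-1, -2)) * 0.5, dim=-1) @ v32
    ref.sum().backward()

    for got, want in ((q.grad, q32.grad), (k.grad, k32.grad), (v.grad, v32.grad)):
        torch.testing.assert_close(got.float(), want, atol=2e-2, rtol=2e-2)
```

This keeps the failing configuration running in CI, so the xfail flips to a visible XPASS the moment the backward pass is fixed.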
vgokhale commented 4 months ago

Could you also make similar changes to flash-attention.py?

jtang10 commented 4 months ago

> Could you also make similar changes to flash-attention.py?

Do you want me to do it here, or can I fold that into the fp8 work in flash-attention.py? Though bf16 and fp8 are separate efforts, I would prefer to consolidate them in flash-attention.py for less overhead, if that is okay.