alexzhang13 / flashattention2-custom-mask

Triton implementation of FlashAttention2 that adds Custom Masks.
Apache License 2.0

Problems with running the small example in the README file #4

Closed eigenvectorBazuz closed 1 month ago

eigenvectorBazuz commented 1 month ago

I tried to run the small examples given in the README in a new conda env with triton 3.0.0 and torch, and ran into two problems:

(1)

TypeError: apply() takes no keyword arguments

I worked around this by passing the arguments unnamed (i.e. mask instead of mask=mask, etc.), but then I ran into a thornier problem:
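
For reference, a minimal sketch of that workaround, assuming the README example goes through a torch.autograd.Function wrapper (the class name, forward signature, and tensor shapes below are illustrative, not the repo's actual API):

```python
import torch

# Hypothetical wrapper mirroring the usual pattern: Function.apply only
# forwards positional arguments, so calling it with keywords raises
# "TypeError: apply() takes no keyword arguments".
class _Attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, mask, sm_scale):
        # The real implementation would launch the Triton kernel here.
        return q

q = k = v = torch.randn(1, 4, 128, 64, dtype=torch.float16)
mask = torch.ones(1, 4, 128, 128, dtype=torch.bool)

# Fails with the TypeError above:
# out = _Attention.apply(q, k, v, mask=mask, sm_scale=0.5)

# Works: pass everything positionally, in the order forward() expects.
out = _Attention.apply(q, k, v, mask, 0.5)
```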

(2)

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
~/.local/lib/python3.9/site-packages/triton/language/core.py in wrapper(*args, **kwargs)
     34                              "(`_builder` argument must be provided outside of JIT functions.)")
---> 35         return fn(*args, **kwargs)
     36 

~/.local/lib/python3.9/site-packages/triton/language/core.py in dot(input, other, acc, input_precision, allow_tf32, max_num_imprecise_acc, out_dtype, _builder)
   1533     max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc)
-> 1534     return semantic.dot(input, other, acc, input_precision, max_num_imprecise_acc, out_dtype, _builder)
   1535 

~/.local/lib/python3.9/site-packages/triton/language/semantic.py in dot(lhs, rhs, acc, input_precision, max_num_imprecise_acc, out_dtype, builder)
   1354     assert lhs.type.is_block() and rhs.type.is_block()
-> 1355     assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.options)
   1356     if lhs.dtype.is_fp8e4b15() or rhs.dtype.is_fp8e4b15():

~/.local/lib/python3.9/site-packages/triton/language/semantic.py in assert_dtypes_valid(lhs_dtype, rhs_dtype, options)
   1327                 return
-> 1328             assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!"
   1329         else:

AssertionError: First input (fp16) and second input (fp32) must have the same dtype!

This one I am at a loss to solve on my own...

Uwwal commented 1 month ago

This is because the QK^T result has to be converted back to fp16 in the code; I will submit a PR later.
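
To make the dtype constraint concrete, here is a minimal, self-contained sketch (not the repo's kernel, and the variable names are made up): two chained tl.dot calls where the fp32 accumulator of the first dot has to be cast back to fp16 before it can be multiplied with an fp16 operand, which is the kind of conversion being described:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def chained_dot_kernel(Q, K, V, Out, N: tl.constexpr):
    offs = tl.arange(0, N)
    idx = offs[:, None] * N + offs[None, :]
    q = tl.load(Q + idx)
    k = tl.load(K + idx)
    v = tl.load(V + idx)
    qk = tl.dot(q, k)  # fp16 x fp16 -> fp32 accumulator
    # Without .to(tl.float16), the second tl.dot mixes fp32 and fp16 inputs
    # and trips the same dtype assertion as in the traceback above.
    out = tl.dot(qk.to(tl.float16), v)
    tl.store(Out + idx, out)

N = 16
q = torch.randn(N, N, device="cuda", dtype=torch.float16)
k = torch.randn(N, N, device="cuda", dtype=torch.float16)
v = torch.randn(N, N, device="cuda", dtype=torch.float16)
out = torch.empty(N, N, device="cuda", dtype=torch.float32)
chained_dot_kernel[(1,)](q, k, v, out, N=N)
```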

alexzhang13 commented 1 month ago

Merged @Uwwal's changes! Let me know if this fixes your issue @eigenvectorBazuz, otherwise I'll take a look at it more deeply later this week.