ROCm / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

roll back fmha/common.py #5

Closed: tenpercent closed this 7 months ago

tenpercent commented 7 months ago

Addressing

The current MHA interface requires H (the number of heads) to match across q, k, and v. If you want to do GQA, e.g. one kv-head for every n q-heads, you have to pass 5D inputs, which forces the user to be very explicit. Do we really want to relax that rule in this PR?
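The rule above can be sketched at the level of tensor shapes. This is a minimal, hypothetical illustration, assuming the 4D layout [B, M, H, K] and the 5D layout [B, M, G, H, K] (G kv-heads, H q-heads per kv-head). check_qkv_heads, gqa_to_5d and kv_head_for_q_head are made-up helpers for illustration, not xformers functions.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def check_qkv_heads(q: Shape, k: Shape, v: Shape) -> None:
    """Hypothetical check mirroring the rule: 4D inputs [B, M, H, K] must agree
    on H; 5D inputs [B, M, G, H, K] must agree on both G and H.
    M is not compared because q and k/v may have different sequence lengths."""
    if not (len(q) == len(k) == len(v)):
        raise ValueError("q, k, v must have the same rank")
    if len(q) == 4:
        if not (q[2] == k[2] == v[2]):
            raise ValueError(
                f"H mismatch: q={q[2]}, k={k[2]}, v={v[2]}; use 5D inputs for GQA"
            )
    elif len(q) == 5:
        if not (q[2:4] == k[2:4] == v[2:4]):
            raise ValueError("G and H must match across q, k, v")
    else:
        raise ValueError("expected 4D [B, M, H, K] or 5D [B, M, G, H, K] inputs")


def gqa_to_5d(b: int, mq: int, mkv: int, hq: int, hkv: int, k: int):
    """Hypothetical helper: shapes for GQA with hkv kv-heads shared by hq q-heads."""
    if hq % hkv != 0:
        raise ValueError("hq must be a multiple of hkv")
    n = hq // hkv  # q-heads per kv-head
    q5 = (b, mq, hkv, n, k)
    kv_view = (b, mkv, hkv, 1, k)  # each kv-head stored once
    kv_expanded = (b, mkv, hkv, n, k)  # broadcast over the n q-heads (no copy needed)
    return q5, kv_view, kv_expanded


def kv_head_for_q_head(hq_index: int, n: int) -> int:
    """Flat q-head hq_index sits at (g, h) = divmod(hq_index, n) and reads kv-head g."""
    g, _h = divmod(hq_index, n)
    return g
```

With 8 q-heads and 2 kv-heads, passing 4D shapes (2, 16, 8, 64) and (2, 128, 2, 64) would be rejected, while the 5D shapes from gqa_to_5d pass. This makes the head grouping explicit in the shape itself instead of being inferred from mismatched H values.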

This PR also merges test_mqa_forward into test_mqa_decoding, as suggested in

To follow up on @bottler's question about @rocm_only: it would be better for fmha.ck.FwOp to be covered by the generic test_forward and test_mqa_decoding. Then we wouldn't need a separate test function, and eventually wouldn't need @rocm_only at all once all such cases are refactored.
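The idea of covering every op with one generic test, instead of a separate @rocm_only test, can be sketched as building a single case list in which unsupported ops carry a skip reason. This is a hypothetical, framework-free illustration: FakeOp, skip_reason and generic_cases are made-up names, and the backend strings are assumptions, not how the xformers test suite is actually organized.

```python
from dataclasses import dataclass
from typing import FrozenSet, List, Optional, Tuple


@dataclass(frozen=True)
class FakeOp:
    """Hypothetical stand-in for an attention op such as fmha.ck.FwOp."""
    name: str
    backends: FrozenSet[str]  # e.g. frozenset({"cuda"}) or frozenset({"rocm"})


def skip_reason(op: FakeOp, backend: str) -> Optional[str]:
    """Return why op cannot run on this backend, or None if it can."""
    if backend not in op.backends:
        return f"{op.name} not supported on {backend}"
    return None


def generic_cases(ops: List[FakeOp], backend: str) -> List[Tuple[str, Optional[str]]]:
    """One case list for a generic test (e.g. test_forward): every op appears,
    and the op's own support check decides whether it is skipped, so no
    separate backend-only test function is needed."""
    return [(op.name, skip_reason(op, backend)) for op in ops]
```

In a real suite, the skip reason would typically be turned into a test-framework skip at parametrization time, so adding a new backend op only means adding it to the shared op list.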

Not sure whether this should block the merge or can be done as a follow-up.

qianfengz commented 7 months ago

@tenpercent

qianfengz commented 7 months ago

Since we added too many scripts for this function, we can just remove them; I will keep the scripts privately for testing and verification.