Tlntin / trt2023

Apache License 2.0
23 stars 3 forks source link

[help wanted] 请问multiHeadCrossAttentionPlugin 是不支持 kv seq > 128 的情况吗? #1

Closed yibolu96 closed 10 months ago

yibolu96 commented 10 months ago

你好大佬,我在测试 cross attention 的时候发现这一行写死了 kv seqlen 的长度,这是为啥呢,我尝试更改这个数字之后发现会报错。以及当我真实的 kv seq len 超过128 维虽然不会报错,但是计算结果会有误。这是为啥呢?

https://github.com/Tlntin/trt2023/blob/5c46e815b65906efae255d02f6bddb3d67fea68b/plugin/multiHeadCrossAttentionPlugin/fmhcaPlugin.cpp#L290

Tlntin commented 10 months ago

比赛时未采用该plugin,这个是从8.6.0那边copy过来然后移植到8.6.1做一下测试而已。 官方最终给出的demo链接:https://github.com/NVIDIA/trt-samples-for-hackathon-cn/tree/master/Hackathon2023/controlnet 可以参考一下这个方案。

yibolu96 commented 10 months ago

感谢大佬