Oneflow-Inc / one-codegeex

Apache License 2.0

fix fmha input bug #13

Closed: BBuf closed this 1 year ago

BBuf commented 1 year ago

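fmha expects input of shape [batch_size, seq_length, hidden_size]. As Juncheng (俊丞) pointed out, the code used the wrong operation here: the inputs q, k, v have shape [seq_length, batch_size, hidden_size], so they should be transposed into fmha's layout rather than reshaped (view). Because batch_size happens to be 1 here, the output was not affected. This PR fixes the bug.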
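The difference can be illustrated with a small sketch. It uses NumPy rather than OneFlow, and the helper names `to_batch_first_wrong` and `to_batch_first` are hypothetical. In OneFlow/PyTorch-style code the corresponding calls would be `reshape`/`view` versus `transpose(0, 1)`. The fmha call itself is not shown. A reshape only reinterprets the existing memory order, while a transpose actually swaps the seq and batch axes. The two agree only when batch_size is 1.

```python
import numpy as np

def to_batch_first_wrong(x):
    """Buggy: reinterpret [seq, batch, hidden] memory as [batch, seq, hidden]."""
    seq_len, batch, hidden = x.shape
    return x.reshape(batch, seq_len, hidden)

def to_batch_first(x):
    """Correct: swap the seq and batch axes, [seq, batch, hidden] -> [batch, seq, hidden]."""
    return np.swapaxes(x, 0, 1)

seq_len, hidden = 4, 8

# batch = 1: reshape and transpose happen to produce the same tensor.
q1 = np.arange(seq_len * 1 * hidden, dtype=np.float32).reshape(seq_len, 1, hidden)
print(np.array_equal(to_batch_first_wrong(q1), to_batch_first(q1)))  # True

# batch = 2: reshape interleaves tokens from different batch elements.
q2 = np.arange(seq_len * 2 * hidden, dtype=np.float32).reshape(seq_len, 2, hidden)
print(np.array_equal(to_batch_first_wrong(q2), to_batch_first(q2)))  # False
```

If the kernel requires contiguous memory, the transposed tensor may also need to be made contiguous, for example with `.contiguous()` in OneFlow/PyTorch or `np.ascontiguousarray` in NumPy. This is an assumption about the kernel, not something stated in this PR.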