Closed BBuf closed 1 year ago
fmha支持的input形状是 [batch_size, seq_length, hidden_size] ,经俊丞提醒这里有个错误的用法,对于输入的q,k,v(形状是[seq_length, batch_size, hidden_size])应该是transpose而不是reshape(view),只不过这里batch刚好是1所以不会影响输出,这个pr修复一下这个bug。
fmha支持的input形状是 [batch_size, seq_length, hidden_size] ,经俊丞提醒这里有个错误的用法,对于输入的q,k,v(形状是[seq_length, batch_size, hidden_size])应该是transpose而不是reshape(view),只不过这里batch刚好是1所以不会影响输出,这个pr修复一下这个bug。