bojone / bert4keras

keras implement of transformers for humans
https://kexue.fm/archives/6915
Apache License 2.0
5.33k stars 924 forks source link

bert4keras的attention_cache是否有传入方式的demo #481

Open hanggun opened 1 year ago

hanggun commented 1 year ago

提问时请尽可能提供如下信息:

基本信息

苏神您好,请问Transformer中看到您有在build中实现attention_cache参数的传入,请问该是要传入所有MHA层的名字么,例如bert-base需要传入12层MHA,[K缓存,V缓存]应该怎么传入,请问有简单实现的demo么,在unilm的模型上可以实现么

update

苏神,仔细阅读了代码后,发现K_cache和V_cache在bert4keras中是与inputs[1]和inputs[2]进行结合后输入到multihead attention,并且在multihead attention中也没有对cache机制进行特别的处理,想问一下attention cache在bert4keras中的作用是什么,在您的博客中也没有搜到相关的记录

hanggun commented 1 year ago

仔细思考之后,我发现Unilm是不能用attention cache的方式进行加速的,对于Encoder Decoder模型来说,K和V是由Encoder获得的,因为输入不变,所以 K和V在每次运算中都是固定的,只有Q在每次增加pred y的时候会改变,因此可以固定K和V,减少计算量,但是对于Unilm来说,由于采用的是单一Encoder,self-attention结构,每次输入的是X+y的拼接,因此,Q,K,V都在改变,无法固定Q和K来进行加速计算,但是在这个问题中https://github.com/bojone/bert4keras/issues/298,苏神说可以引入cache机制加速beam search的推理,那这里的cache机制应该加在哪里呢?从我的理解来看,cache机制只能加在seq2seq结构中,不知道有没有错

i4never commented 1 year ago

unilm也不是完全不行、只不过seq2seq的decoder self attention由于三角mask可以缓存softmax后的结果、unilm的话应该至少可以缓存qkv的投影,少做几次乘法运算 推销一篇我的blog看看

xv44586 commented 1 year ago

可以参考我这篇,完全基于bert4keras 实现的attention cache: faster-decoder