wenet-e2e / wenet

Production First and Production Ready End-to-End Speech Recognition Toolkit
https://wenet-e2e.github.io/wenet/
Apache License 2.0
3.87k stars 1.04k forks source link

[transformer] refactor cache #2480

Closed Mddct closed 2 months ago

Mddct commented 2 months ago

目前transformer cache 有三种实现: 1 wenet 这种,torch.cat((k,v), dim=-1)

2 k, v 不需要concat 一起, 直接tupel(k,v), 由上层去决定怎么导出(jit, onnx etc)

3 LLM 的实现, 会把 k v cache 先申请max_len的, 然后tuple(k_cache, v_cache), 在attention 里边去写 copy (去写最一开始前申请好) 详细见:

该pr 封装函数 _update_kv_cache, 第一是避免到处同样的代码, 第二十可支持1,2类型 (未来有需求支持3)