Closed Mddct closed 2 months ago
目前transformer cache 有三种实现: 1 wenet 这种,torch.cat((k,v), dim=-1)
2 k, v 不需要concat 一起, 直接tupel(k,v), 由上层去决定怎么导出(jit, onnx etc)
3 LLM 的实现, 会把 k v cache 先申请max_len的, 然后tuple(k_cache, v_cache), 在attention 里边去写 copy (去写最一开始前申请好) 详细见:
该pr 封装函数 _update_kv_cache, 第一是避免到处同样的代码, 第二十可支持1,2类型 (未来有需求支持3)
目前transformer cache 有三种实现: 1 wenet 这种,torch.cat((k,v), dim=-1)
2 k, v 不需要concat 一起, 直接tupel(k,v), 由上层去决定怎么导出(jit, onnx etc)
3 LLM 的实现, 会把 k v cache 先申请max_len的, 然后tuple(k_cache, v_cache), 在attention 里边去写 copy (去写最一开始前申请好) 详细见:
该pr 封装函数 _update_kv_cache, 第一是避免到处同样的代码, 第二十可支持1,2类型 (未来有需求支持3)