PaddlePaddle / PaddleHelix

Bio-Computing Platform Featuring Large-Scale Representation Learning and Multi-Task Deep Learning (PaddleHelix bio-computing toolkit)
Apache License 2.0

Optimize the implementation of StructureModule. #242

Closed: Xreki closed this pull request 1 year ago.

Xreki commented 1 year ago

For the code changes in this PR, I wrote a unit test that compares numerical precision: https://gist.github.com/Xreki/f451fcb6c3dfe7d83d137b3f7c0ca3f1
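The gist is not reproduced here. As a rough illustration of what such a precision check can look like, the sketch below compares a component-wise rotation product (a hypothetical stand-in for an unfused implementation that builds each of the 9 output entries from scalar slices) against a single batched matmul. It uses NumPy instead of Paddle, random matrices stand in for rotations (the product identity does not depend on orthogonality), and the names `rots_mul_rots_reference` and `rots_mul_rots_matmul` are illustrative only.

```python
import numpy as np


def rots_mul_rots_reference(a, b):
    """Hypothetical unfused reference: each of the 9 output entries is
    assembled from scalar slices, so many small ops are launched."""
    out = np.empty(np.broadcast_shapes(a.shape, b.shape), dtype=a.dtype)
    for i in range(3):
        for j in range(3):
            out[..., i, j] = (a[..., i, 0] * b[..., 0, j]
                              + a[..., i, 1] * b[..., 1, j]
                              + a[..., i, 2] * b[..., 2, j])
    return out


def rots_mul_rots_matmul(a, b):
    """Same product as one batched matmul over the last two axes."""
    return np.matmul(a, b)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 256, 8, 3, 3)).astype(np.float32)
b = rng.standard_normal((2, 256, 8, 3, 3)).astype(np.float32)

ref = rots_mul_rots_reference(a, b)
fast = rots_mul_rots_matmul(a, b)

# float32 results may differ by a few ULPs because summation order can differ.
max_abs_diff = float(np.max(np.abs(ref - fast)))
np.testing.assert_allclose(fast, ref, rtol=1e-5, atol=1e-5)
print(f"max abs diff: {max_abs_diff:.2e}")
```

Comparing against a tolerance rather than exact equality is deliberate: a fused kernel is free to accumulate in a different order than the unfused version.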

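I collected the input and output shapes of rots_mul_rots and rots_mul_vecs in the model and found that there are mainly two configurations for each: one where the inputs need no broadcasting and one where they do.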

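rots_mul_rots:

| Case | a.shape | b.shape | out.shape | Notes |
|---|---|---|---|---|
| No broadcasting | [2, 256, 8, 3, 3] | [2, 256, 8, 3, 3] | [2, 256, 8, 3, 3] | Original implementation: 107 operators; after this PR: 1 operator |
| Broadcasting | [2, 256, 1, 3, 3] | [2, 256, 8, 3, 3] | [2, 256, 8, 3, 3] | Original implementation: 107 operators; after this PR: 3 operators |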
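The broadcasting row differs from the plain row only by the singleton axis in a.shape. The NumPy sketch below shows that a batched matmul can broadcast that axis implicitly, and that an explicit expand followed by matmul gives the same result. Which extra operators the PR actually emits in the broadcasting case is not stated here; the explicit-expand variant is only an assumption about one plausible reason for the higher count.

```python
import numpy as np

rng = np.random.default_rng(0)
# Shapes from the "Broadcasting" row: a has size 1 where b has size 8.
a = rng.standard_normal((2, 256, 1, 3, 3)).astype(np.float32)
b = rng.standard_normal((2, 256, 8, 3, 3)).astype(np.float32)

# np.matmul broadcasts the leading (batch) axes, so no copy is required.
out_implicit = np.matmul(a, b)

# Explicit variant: expand a to b's shape first (a read-only view here),
# then multiply. Backends without broadcasting matmul would need this step.
a_expanded = np.broadcast_to(a, b.shape)
out_explicit = np.matmul(a_expanded, b)

assert out_implicit.shape == (2, 256, 8, 3, 3)
np.testing.assert_allclose(out_implicit, out_explicit, rtol=1e-6, atol=1e-6)
```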
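
rots_mul_vecs:

| Case | m.shape | v.shape | out.shape | Notes |
|---|---|---|---|---|
| No broadcasting | [2, 256, 14, 3, 3] | [2, 256, 14, 3] | [2, 256, 14, 3] | After this PR: 3 operators |
| Broadcasting | [2, 256, 1, 3, 3] | [2, 256, 8, 3] | [2, 256, 8, 3] | After this PR: 5 operators |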
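For rots_mul_vecs, one common way to reuse a batched matmul is to treat each vector as a 3x1 matrix: add a trailing axis, multiply, then drop the axis again (unsqueeze, matmul, squeeze). The NumPy sketch below applies that to both shape configurations from the table; `rots_mul_vecs` here is an illustrative re-implementation, not the PR's code, and it is only an assumption that the PR's operator counts come from a decomposition like this.

```python
import numpy as np


def rots_mul_vecs(m, v):
    """Apply rotations m [..., 3, 3] to vectors v [..., 3] (illustrative).

    v[..., None] turns each vector into a 3x1 matrix, np.matmul does the
    batched product (broadcasting leading axes), and [..., 0] drops the
    trailing axis again.
    """
    return np.matmul(m, v[..., None])[..., 0]


rng = np.random.default_rng(0)

# "No broadcasting" row: m [2, 256, 14, 3, 3], v [2, 256, 14, 3].
m = rng.standard_normal((2, 256, 14, 3, 3)).astype(np.float32)
v = rng.standard_normal((2, 256, 14, 3)).astype(np.float32)
out = rots_mul_vecs(m, v)

# Cross-check against an einsum formulation of the same contraction.
np.testing.assert_allclose(out, np.einsum("...ij,...j->...i", m, v),
                           rtol=1e-5, atol=1e-5)

# "Broadcasting" row: m [2, 256, 1, 3, 3], v [2, 256, 8, 3].
m_b = rng.standard_normal((2, 256, 1, 3, 3)).astype(np.float32)
v_b = rng.standard_normal((2, 256, 8, 3)).astype(np.float32)
out_b = rots_mul_vecs(m_b, v_b)

print(out.shape, out_b.shape)  # (2, 256, 14, 3) (2, 256, 8, 3)
```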