PaddlePaddle / Paddle

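PArallel Distributed Deep LEarning: Machine Learning Framework from Industrial Practice (the core framework of PaddlePaddle (『飞桨』): high-performance single-machine and distributed training and cross-platform deployment for deep learning & machine learning)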
http://www.paddlepaddle.org/
Apache License 2.0
22.29k stars 5.61k forks

[API] Optimize `paddle.where` and `paddle.where_` in eager mode #69556

Open HydrogenSulfate opened 6 days ago

HydrogenSulfate commented 6 days ago

PR Category

Performance Optimization

PR Types

Improvements

Description

Pcard-75624

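  1. The original implementation of `paddle.where` and `paddle.where_` broadcast implicitly by calling several basic binary operators so that `cond`, `x`, and `y` ended up with the same shape. After this optimization, the broadcast shape is computed with `broadcast_shape`, which performs no computation on tensor data, and the inputs are then broadcast with `broadcast_to`. This greatly simplifies the code and removes unnecessary forward and backward operator overhead.
  2. Because fewer operators are now launched, some of the eager-mode (dynamic graph) unit tests for `where_` were re-adapted. A few static-graph unit tests were also fixed in which the placeholder shape did not match the shape of the actual data.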
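A minimal sketch contrasting the two strategies, using NumPy as a stand-in for Paddle. Assumptions: `np.broadcast_to` and `np.where` play the roles of `paddle.broadcast_to` and `paddle.where`, and the hand-written `broadcast_shape` mimics the shape-only semantics of `paddle.broadcast_shape`. The helpers `where_implicit`, `where_broadcast`, and `broadcast_shape` are hypothetical, and the zero-adding chain only illustrates the idea of broadcasting through binary operators; it is not the exact code this PR removes.

```python
import numpy as np


def broadcast_shape(*shapes):
    """Hypothetical helper: compute the broadcast shape from shapes alone.

    Shapes are right-aligned (left-padded with 1s); in each position the
    sizes must be equal or 1. No tensor data is read or written.
    """
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        non_one = {d for d in dims if d != 1}
        if len(non_one) > 1:
            raise ValueError(f"incompatible shapes: {shapes}")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


def where_implicit(cond, x, y):
    """Old-style sketch: reach a common shape by chaining elementwise ops.

    Every zeros_like/add/astype below allocates and fills a full array; in an
    autograd framework each op would also record a backward node.
    """
    zeros = np.zeros_like(x) + np.zeros_like(y) + np.zeros_like(cond, dtype=x.dtype)
    cond_b = (cond.astype(x.dtype) + zeros).astype(bool)
    x_b = x + zeros
    y_b = y + zeros
    return np.where(cond_b, x_b, y_b)


def where_broadcast(cond, x, y):
    """New-style sketch: compute the shape once, then broadcast explicitly."""
    shape = broadcast_shape(cond.shape, x.shape, y.shape)  # shapes only
    cond_b = np.broadcast_to(cond, shape)  # read-only view, no copy
    x_b = np.broadcast_to(x, shape)
    y_b = np.broadcast_to(y, shape)
    return np.where(cond_b, x_b, y_b)


if __name__ == "__main__":
    cond = np.array([[True], [False], [True]])       # shape (3, 1)
    x = np.arange(4, dtype=np.float32).reshape(1, 4)  # shape (1, 4)
    y = np.array(-1.0, dtype=np.float32)              # shape ()
    print(where_broadcast(cond, x, y))
```

In NumPy, `np.where` already broadcasts its arguments, so the explicit steps only mirror the structure described above. The design point is that the shape computation touches no data and `np.broadcast_to` returns a view, whereas each add in the old-style chain materializes a full-size intermediate array.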

[!NOTE] Because `broadcast_shape` and `broadcast_to` do not yet fully support dynamic shapes, this PR only modifies the eager-mode (dynamic graph) branch.

paddle-bot[bot] commented 6 days ago

Your PR has been submitted. Thanks for your contribution! Please wait for the CI results first. See the Paddle CI Manual for details.