Closed megemini closed 3 days ago
Your PR has been submitted. Thanks for your contribution to the open-source project! Please watch the results of the automated CI tests that follow. See the Paddle CI Manual for details.
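@cxxly When I submitted last night, the tests passed locally. Then I did a merge in between, and the unit tests started failing with precision errors ... Do you know of any changes to op_test on your side?
The code above that compares against PyTorch, however, shows no problems ...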
------ optimizer is : NAdam ------
------ compare cpu ------
------- compare finish ---------
------ compare gpu ------
W0419 07:04:03.844563 221339 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.2, Runtime API Version: 11.7
W0419 07:04:03.845705 221339 gpu_resources.cc:164] device: 0, cuDNN Version: 8.5.
------- compare finish ---------
------ optimizer is : RAdam ------
------ compare cpu ------
------- compare finish ---------
------ compare gpu ------
------- compare finish ---------
I'll take another look as well ... dozens of PRs were merged in between ...
I see that CI also failed on the pre-merge commit https://github.com/PaddlePaddle/Paddle/pull/63671/commits/2752082963dad54d8777727c6d481b8d2e664112. Last night only a PR that moved the unit test directory was merged, so it should have no effect.
That CI run did not finish because it reported a conflict, so I merged again in between ~
Before the merge it passed locally; after the merge it started failing ~
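I just traced it, and it looks like the problem is in the op_test logic. For example, nadam_step in the test case simulates the operator with numpy and should take only 1 step, but after the merge op_test appears to take 2 steps, which is why it fails ~
Below is some output I printed from inside op_test: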
!!!!! inputs
{'param': array([[ 0.17602904, 0.3982175 , -0.6236961 , ..., 0.6821805 ,
0.9441111 , 0.5645144 ],
[-0.605939 , 0.22125213, -0.04228899, ..., -0.7997418 ,
0.02449642, -0.02278226],
[-0.01884528, 0.50991696, 0.8460069 , ..., -0.05510011,
-0.30808878, -0.4063227 ],
...,
[-0.9146876 , -0.0642081 , 0.8627311 , ..., -0.33185562,
0.5504968 , 0.14223264],
[ 0.15039325, -0.45783746, 0.10544593, ..., 0.6430891 ,
0.8044486 , 0.12357192],
[-0.23231815, 0.17546286, 0.13293621, ..., -0.4734008 ,
0.47730142, -0.41493982]], dtype=float32), 'grad': array([[ 0.59897286, -0.96861345, -0.8794145 , ..., -0.9679654 ,
0.35871917, -0.5837766 ],
[ 0.32519364, -0.622245 , -0.33584148, ..., 0.10742767,
-0.6007129 , 0.2832118 ],
[ 0.37178418, 0.80451465, 0.23622747, ..., -0.65858865,
0.39903817, 0.9290114 ],
...,
[-0.53591865, 0.65896475, -0.92882305, ..., -0.14718425,
-0.56659144, 0.7059012 ],
[ 0.8947429 , 0.6457763 , 0.9349599 , ..., 0.24845287,
-0.41513 , -0.24681845],
[ 0.96900505, -0.4984258 , -0.21041307, ..., 0.25712886,
0.59644985, -0.53189933]], dtype=float32),
'momentum_decay_pow': array(0.884736, dtype=float32),
'beta2_pow': array(0.7660609, dtype=float32),
'mu_product': array(0.474552, dtype=float32),
'moment1': array([[ 0.56653005, 0.61273986, -0.00243789, ..., 0.80402356,
-0.61513036, 0.7725169 ],
[-0.76927406, 0.9316274 , 0.20868428, ..., -0.53259283,
0.52867246, -0.5667158 ],
[-0.07326724, 0.33402517, -0.1505982 , ..., -0.47788525,
-0.3946228 , 0.58341897],
...,
[-0.5391075 , 0.94511294, -0.9080138 , ..., 0.9197616 ,
-0.84012556, -0.73314774],
[ 0.65273803, 0.5870666 , 0.18124168, ..., 0.38077796,
-0.7243844 , -0.00442055],
[-0.75882095, 0.18046115, 0.09214061, ..., 0.4812588 ,
-0.8366428 , -0.67993087]], dtype=float32), 'moment2': array([[0.46832946, 0.7352264 , 0.9196088 , ..., 0.04938603, 0.8743996 ,
0.21438688],
[0.39720523, 0.2538434 , 0.5387814 , ..., 0.96552783, 0.09327286,
0.8940937 ],
[0.04364423, 0.79084235, 0.9966176 , ..., 0.19338697, 0.90396744,
0.7609614 ],
...,
[0.09379539, 0.4152249 , 0.4628522 , ..., 0.23871422, 0.80177 ,
0.47308218],
[0.86667293, 0.1730922 , 0.49166903, ..., 0.60770994, 0.8060526 ,
0.44773874],
[0.1273331 , 0.2625222 , 0.22494832, ..., 0.80136675, 0.05598379,
0.01957218]], dtype=float32), 'learning_rate': array(0.003, dtype=float32)}
!!!!! attrs
{'epsilon': 1e-08, 'beta1': 0.78, 'beta2': 0.915, 'momentum_decay': 0.004, 'momentum_decay_base': 0.96}
--------------------
param_out actual_np
[[ 0.1743741 0.39931554 -0.6224579 ... 0.6845692 0.9439621
0.56524587]
[-0.60594314 0.22185223 -0.04182015 ... -0.799576 0.02598498
-0.02284868]
[-0.02085438 0.50848264 0.84576494 ... -0.05261055 -0.30842945
-0.4081155 ]
...
[-0.91161144 -0.06639929 0.8652213 ... -0.33253878 0.551909
0.14145973]
[ 0.14870569 -0.46055287 0.10357524 ... 0.6423588 0.80555993
0.12408908]
[-0.23428343 0.17657013 0.13343854 ... -0.47411665 0.47615242
-0.4095929 ]]
param_out expect_np
[[ 0.1743592 0.3993776 -0.62242097 ... 0.6847349 0.9439293
0.56533515]
[-0.6059954 0.22194509 -0.04179393 ... -0.7995945 0.02609378
-0.0228765 ]
[-0.02092784 0.5084554 0.84575117 ... -0.05257884 -0.30845746
-0.4081416 ]
...
[-0.9115866 -0.06640426 0.86524254 ... -0.3324781 0.5519114
0.14139257]
[ 0.14868431 -0.46057892 0.10352963 ... 0.642358 0.8055586
0.12410427]
[-0.23441313 0.17661788 0.13346191 ... -0.4741149 0.4759967
-0.40957034]]
--------------------
momentum_decay_pow_out actual_np
0.8153727
momentum_decay_pow_out expect_np
0.8493466
--------------------
beta2_pow_out actual_np
0.64136535
beta2_pow_out expect_np
0.70094573
--------------------
mu_product_out actual_np
0.072285436
mu_product_out expect_np
0.18519613
--------------------
moment1_out actual_np
[[ 0.57366747 0.2648421 -0.19537276 ... 0.41418594 -0.40088344
0.47413227]
[-0.52849114 0.5897754 0.08888859 ... -0.3917883 0.28020766
-0.37973168]
[ 0.02464409 0.43753287 -0.06549654 ... -0.51764 -0.22001737
0.65944934]
...
[-0.53840595 0.8821603 -0.9125919 ... 0.6850335 -0.77994806
-0.4165569 ]
[ 0.7059791 0.59998274 0.34705973 ... 0.35166642 -0.65634847
-0.0577481 ]
[-0.37869918 0.031106 0.0255788 ... 0.4319502 -0.52136236
-0.6473639 ]]
moment1_out expect_np
[[ 0.57366747 0.26484212 -0.19537275 ... 0.414186 -0.40088344
0.4741323 ]
[-0.52849114 0.58977544 0.0888886 ... -0.3917883 0.2802077
-0.3797317 ]
[ 0.02464408 0.43753284 -0.06549655 ... -0.51764 -0.22001737
0.6594493 ]
...
[-0.5384059 0.8821603 -0.9125918 ... 0.6850335 -0.77994806
-0.41655692]
[ 0.7059791 0.59998274 0.34705967 ... 0.35166642 -0.65634847
-0.05774809]
[-0.3786992 0.03110601 0.0255788 ... 0.4319502 -0.52136236
-0.6473639 ]]
--------------------
moment2_out actual_np
[[0.45901677 0.75248015 0.90717846 ... 0.12482955 0.8110134 0.22513157]
[0.37243164 0.26517776 0.5025721 ... 0.88443893 0.11601742 0.8249135 ]
[0.05168346 0.77863646 0.9166484 ... 0.21381688 0.84066486 0.76963997]
...
[0.11023553 0.41684073 0.4968403 ... 0.22026488 0.76090676 0.4752254 ]
[0.86105376 0.19382665 0.52417994 ... 0.5613015 0.7521865 0.41485912]
[0.19632229 0.2613242 0.20959097 ... 0.73887044 0.08146411 0.04195648]]
moment2_out expect_np
[[0.45901677 0.7524802 0.90717846 ... 0.12482958 0.8110134 0.22513159]
[0.37243164 0.26517776 0.5025721 ... 0.88443893 0.11601742 0.8249135 ]
[0.05168346 0.7786365 0.9166484 ... 0.2138169 0.84066486 0.76963997]
...
[0.11023553 0.41684073 0.49684033 ... 0.22026488 0.76090676 0.47522542]
[0.86105376 0.19382666 0.52417994 ... 0.5613015 0.7521865 0.41485912]
[0.19632232 0.26132423 0.20959097 ... 0.73887044 0.08146413 0.04195648]]
As you can see, taking momentum_decay_pow as an example, the update is momentum_decay_pow *= momentum_decay_base. With the inputs above, momentum_decay_pow = 0.884736 and momentum_decay_base = 0.96. If only 1 step is taken, momentum_decay_pow = 0.884736 * 0.96 = 0.84934656, which matches expect_np. But op_test took two steps, giving momentum_decay_pow = 0.884736 * 0.96 * 0.96 = 0.815372698, which is exactly actual_np! expect_np and actual_np differ, hence the error!
beta2_pow and mu_product are affected as well, so param differs too!
beta2_pow = 0.7660609 * 0.915 = 0.700945724 -> 1 step
beta2_pow = 0.7660609 * 0.915 * 0.915 = 0.641365337 -> 2 steps
Before the merge there was no problem. I did a similar comparison in op_test back then, and it did take exactly 1 step!
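The arithmetic above can be re-checked with a short script. This is only a sketch: the numbers are copied from the printed inputs, attrs, and outputs in this thread, while `advance` and `infer_steps` are hypothetical helper names introduced here for illustration and are not part of op_test.

```python
import math

# Values copied from the printed op_test inputs and attrs in this thread.
MOMENTUM_DECAY_BASE = 0.96
BETA2 = 0.915
momentum_decay_pow_in = 0.884736
beta2_pow_in = 0.7660609

# Printed outputs: actual_np (from op_test) and expect_np (NumPy reference).
printed = {
    "momentum_decay_pow_out": (momentum_decay_pow_in, MOMENTUM_DECAY_BASE,
                               0.8153727, 0.8493466),
    "beta2_pow_out": (beta2_pow_in, BETA2, 0.64136535, 0.70094573),
}


def advance(value, factor, steps):
    """Apply `value *= factor` the given number of times."""
    for _ in range(steps):
        value *= factor
    return value


def infer_steps(value_in, value_out, factor):
    """Hypothetical helper: estimate how many multiplications by `factor`
    turn `value_in` into `value_out`."""
    return round(math.log(value_out / value_in) / math.log(factor))


for name, (value_in, factor, actual, expect) in printed.items():
    print(
        f"{name}: expect_np looks like {infer_steps(value_in, expect, factor)} step(s), "
        f"actual_np looks like {infer_steps(value_in, actual, factor)} step(s)"
    )
```

Under these assumptions, the reference values correspond to one update and the op_test values to two, which is consistent with the diagnosis in the comment above.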
Also, the PyTorch results match my operator. For now it looks like op_test may have an issue when it appends the op ~ Only the static graph is involved here ~
The branch I had before the merge was probably two or three months old ... ... Recompiling after every merge takes too long, so I got lazy ... ...
Do you want to bisect to locate it?
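Thanks! It was my mistake: I had modified op_test locally ... ...
Sorry for wasting everyone's time. Once CI passes I will @ everyone again for review ~ Sorry ~
@cxxly CI has mostly passed ~ The earlier failures came from problems with my local unit tests, which are now resolved ~ Please review, thanks ~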
Sorry to inform you that 60f8ebe's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
Please add detailed test cases, using the unit tests related to test_optimizer and test_adam as a reference.
@zyfncg please review for operator-related conventions.
Sorry to inform you that f9cbe72's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
The momentum_decay_base parameter: uniformly changed to 0.96.
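CI has mostly passed now. However, PR-CI-Coverage timed out, possibly because this PR adds quite a lot of unit tests ... ...
@cxxly Could you help me figure out what to do? Thanks ~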
Regarding the PR-CI-Coverage timeout: you can raise the time limit with something like set_tests_properties(test_autograd_dynamic PROPERTIES TIMEOUT 100).
Also, radam / nadam now pass in PR-CI-Coverage. The failing APIs should not involve these two ~
@cxxly please review ~
@sunzhongkai588 please review ~
顺师傅, you can go ahead and submit the Chinese documentation. Also, the PIR pipeline is failing:
2024-05-14 22:17:04 The following tests FAILED:
2024-05-14 22:17:04 1036 - test_paddlescience (Failed)
2024-05-14 22:17:04 1067 - test_radam_op (Failed)
2024-05-14 22:17:04 1114 - test_paddlescience (Failed)
2024-05-14 22:17:04 1147 - test_radam_op (Failed)
2024-05-14 22:17:04 1066 - test_nadam_op (Failed)
2024-05-14 22:17:04 995 - test_nadam_op (Failed)
It seems PR-CI-Py3-PIR was not Required before, so I did not pay much attention to it ... ...
I took a look at the log:
2024-05-15 14:23:33 Traceback (most recent call last):
2024-05-15 14:23:33 File "/workspace/Paddle/build/test/legacy_test/test_nadam_op.py", line 292, in test_nadam_static
2024-05-15 14:23:33 conv = paddle.static.nn.conv2d(data, 8, 3)
2024-05-15 14:23:33 File "/workspace/Paddle/build/python/paddle/static/nn/common.py", line 1061, in conv2d
2024-05-15 14:23:33 helper.append_op(
2024-05-15 14:23:33 File "/workspace/Paddle/build/python/paddle/base/layer_helper.py", line 50, in append_op
2024-05-15 14:23:33 return self.main_program.current_block().append_op(*args, **kwargs)
2024-05-15 14:23:33 File "/workspace/Paddle/build/python/paddle/base/framework.py", line 4609, in append_op
2024-05-15 14:23:33 op = Operator(
2024-05-15 14:23:33 File "/workspace/Paddle/build/python/paddle/base/framework.py", line 3234, in __init__
2024-05-15 14:23:33 raise TypeError(
2024-05-15 14:23:33 TypeError: The type of '%Input' in operator conv2d should be one of [str, bytes, Variable]. but received : Value(define_op_name=pd_op.data, index=0, dtype=builtin.tensor<2x3x8x8xf32>, stop_gradient=True)
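The problem is that the conv2d call in the test case fails under PIR ~
These test cases were modeled on some earlier ADAM test cases, so the calling convention under PIR must have changed since then ~
I'll fix it then ~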
PR Category
User Experience
PR Types
New features
Description
NO.13 Add the RAdam / NAdam API to Paddle
Related RFC:
Local tests pass. In addition, the results were compared against PyTorch using the following code, and they are consistent:
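The comparison covers the optimized results of both NAdam and RAdam on CPU and on GPU, and they are consistent. Along the way I noticed a common issue: when the number of optimization steps is large, the CPU and GPU results can differ by more than 1e-06. The GPU results agree well with PyTorch ~ Other Paddle optimizers show the same issue, so I am pointing it out here ~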
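Also, because Paddle's optimizer algorithms are implemented with accumulators, the concrete implementation logic differs from the original algorithm even though the optimization results are the same ~ For the concrete implementation, see the radam_step and nadam_step functions in the test_xxx_op.py test cases ~
@cxxly please review ~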
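As a rough illustration of what an accumulator-style update means in general (this is not the code from test_xxx_op.py; the function names below are hypothetical), a power such as beta2 ** t can either be recomputed from the step count t or carried along as a running product that is multiplied once per step. The sketch below checks that the two agree. It also shows that the inputs printed earlier in this thread, 0.884736 and 0.7660609, equal 0.96 ** 3 and 0.915 ** 3 up to rounding.

```python
import math


def pow_closed_form(base, step):
    """Textbook-style form: recompute base ** step from the step count."""
    return base ** step


def pow_accumulated(base, steps):
    """Accumulator-style form: keep a running product and multiply once per step.

    Returns the accumulator value after each step (steps 1..steps).
    """
    acc = 1.0
    history = []
    for _ in range(steps):
        acc *= base  # the only state carried between steps is `acc`
        history.append(acc)
    return history


# beta2 and momentum_decay_base as given in the attrs printed in this thread.
for base in (0.915, 0.96):
    for step, acc in enumerate(pow_accumulated(base, 5), start=1):
        assert math.isclose(acc, pow_closed_form(base, step), rel_tol=1e-12)
    print(base, pow_accumulated(base, 3)[-1])
```

One consequence of the running-product form is that applying the update one extra time changes the stored state itself, which is why an extra step shows up directly in outputs such as beta2_pow_out.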