With both my implementation and an older one written by someone else, I got this assertion error:
But if I print any one of `m`, `v`, or `grad`, the test passes (for both my implementation and the other one):
Note that I convert `m` and `v` to `ndl.Tensor` only once and use `.data` in the calculation, and that I can pass the previous test, `test_optim_sgd_z_memory_check_1`.
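To make clear what I mean by using `.data`, here is a minimal sketch of the pattern. It does not use needle: the `Tensor` class below is a hypothetical stand-in that only counts live instances, and `momentum_step` and `live_after` are names I made up for illustration. It assumes CPython, where an object is freed as soon as its last reference goes away.

```python
import gc


class Tensor:
    """Hypothetical stand-in for ndl.Tensor; it only tracks live instances."""

    live = 0  # number of Tensor objects currently alive

    def __init__(self, value, inputs=()):
        self.value = float(value)
        self.inputs = tuple(inputs)  # references that keep the graph alive
        Tensor.live += 1

    def __del__(self):
        Tensor.live -= 1

    @property
    def data(self):
        # Detached copy: same value, no reference back to the graph.
        return Tensor(self.value)

    def __add__(self, other):
        return Tensor(self.value + other.value, inputs=(self, other))

    def __mul__(self, scalar):
        return Tensor(self.value * scalar, inputs=(self,))


def momentum_step(m, grad, beta=0.9, detach=True):
    """One moving-average update: m = beta * m + (1 - beta) * grad."""
    if detach:
        # Work on detached values so no history is retained between steps.
        return (m.data * beta + grad.data * (1 - beta)).data
    # Without detaching, each result keeps references to its inputs.
    return m * beta + grad * (1 - beta)


def live_after(steps, detach):
    """Return how many extra tensors are alive after `steps` updates."""
    gc.collect()
    base = Tensor.live
    m, grad = Tensor(0.0), Tensor(1.0)
    for _ in range(steps):
        m = momentum_step(m, grad, detach=detach)
    gc.collect()
    return Tensor.live - base


if __name__ == "__main__":
    print(live_after(5, detach=True))
    print(live_after(5, detach=False))
```

In this toy version, the detached variant keeps only `m` and `grad` alive no matter how many steps run, while the non-detached variant retains three more tensors per step. The real counter in needle may behave differently, so this only shows the pattern, not the cause of the failure.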
Could anyone explain this? I'm quite confused about why printing would affect the tensor count. Thanks for the help :)