Closed by CloudyDory.
Thanks for the report!
The problem can be fixed by changing the line
new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array) # accumulate gradients
into
new_grad = jax.tree_map(bm.add, last_grad, grads) # accumulate gradients
Please let me know whether this change fixes the error.
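The suggested fix sums the two gradient pytrees leaf by leaf and returns plain arrays, instead of wrapping every sum in a new `bm.TrainVar`. The following is a minimal plain-JAX sketch of that accumulation pattern. It is not the reporter's code and does not use BrainPy. It assumes a logistic-regression loss and dict-of-arrays parameters, and the names `loss_fn`, `accumulate` and `num_micro` are hypothetical. It uses `jax.tree_util.tree_map`, the non-deprecated spelling of `jax.tree_map`.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Logistic regression: mean binary cross-entropy on logits x @ w + b,
    # written as log(1 + exp(z)) - y * z for numerical stability.
    logits = x @ params["w"] + params["b"]
    return jnp.mean(jnp.logaddexp(0.0, logits) - y * logits)

grad_fn = jax.jit(jax.grad(loss_fn))

def accumulate(acc, grads):
    # Leaf-wise sum of two gradient pytrees with identical structure.
    # The result is ordinary arrays; no new state objects are created.
    return jax.tree_util.tree_map(jnp.add, acc, grads)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = (x[:, 0] > 0).astype(jnp.float32)
params = {"w": jnp.zeros(3), "b": jnp.zeros(())}

# Start from zeros shaped like the parameters, then add each micro-batch gradient.
acc = jax.tree_util.tree_map(jnp.zeros_like, params)
num_micro = 4
for xb, yb in zip(jnp.split(x, num_micro), jnp.split(y, num_micro)):
    acc = accumulate(acc, grad_fn(params, xb, yb))

# Average the accumulated gradient and apply a single SGD step.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g / num_micro, params, acc)
```

Because every micro-batch has the same size and the loss is a mean, dividing the accumulated sum by `num_micro` recovers the full-batch gradient.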
Thank you very much for the reply, it fixes the error. Could you briefly explain why it happens?
The cause of this error is not very intuitive. It comes down to how BrainPy traces variables (the original line wraps every accumulated gradient in a new bm.TrainVar). I wouldn't recommend digging into it. :joy::joy:
After upgrading to BrainPy 2.5.0, I found that training with gradient accumulation no longer works.
We can use logistic regression as an example:
On BrainPy 2.4.6.post5, the above code trains normally, but on BrainPy 2.5.0 it raises the following error:
Environment (BrainPy 2.5.0):
Environment (BrainPy 2.4.6.post5):