juanitorduz closed this 8 months ago
Looks great to me. Could you add a simple test with a while loop in the model?
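For context, a while loop in the model is a useful test case because JAX's reverse-mode `jax.grad` does not support `lax.while_loop`, while forward mode (`jax.jacfwd`) does. The sketch below is pure JAX and is not the test from this PR; the toy `loss` function is hypothetical and only stands in for a model whose body contains a while loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def loss(x):
    # Hypothetical toy objective: multiply by x three times inside a
    # lax.while_loop, so the result equals x ** 3.
    def cond_fun(carry):
        i, _ = carry
        return i < 3

    def body_fun(carry):
        i, val = carry
        return i + 1, val * x

    _, result = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.ones_like(x)))
    return result


x = jnp.array(2.0)

# Forward mode differentiates through lax.while_loop:
# d/dx x**3 = 3 * x**2, which is 12 at x = 2.
fwd_grad = jax.jacfwd(loss)(x)

# Reverse mode is not supported for lax.while_loop and raises an error.
try:
    jax.grad(loss)(x)
    reverse_ok = True
except Exception:
    reverse_ok = False

print(fwd_grad, reverse_ok)
```

A model that only works under forward mode, like this one, is what distinguishes a forward-mode code path from the default reverse-mode one in a test.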
I added a simple test, `test/infer/test_svi.py::test_forward_mode_differentiation`, in https://github.com/pyro-ppl/numpyro/pull/1731/commits/532fb5b186d445fda626c3eb682a3d8b666dfb21, but it fails with:
```
    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
>       raise TypeError(msg.format(tree, tree2))
E       TypeError: optimizer update function was passed a gradient tree that did not match the parameter tree structure with which it was initialized: parameter tree PyTreeDef({'loc': *, 'scale': *}) and grad tree PyTreeDef(({'loc': *, 'scale': *}, None))
```
I am not sure whether this is because my implementation of the test is wrong 😑. Any tips? Thanks!
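Reading the error message: both trees hold the same two leaves (`loc` and `scale`), but the gradient arrived wrapped in a 2-tuple whose second element is `None`. The sketch below reproduces that mismatch with `jax.tree_util.tree_structure`, assuming (not confirmed in this thread) that the gradient function returned a `(grads, aux)` pair that was passed to the optimizer's `update` without being unpacked. The `params` and `grads` values are hypothetical placeholders.

```python
from jax.tree_util import tree_structure

# Hypothetical parameters: the structure the optimizer was initialized with.
params = {"loc": 0.0, "scale": 1.0}
param_tree = tree_structure(params)

# Hypothetical gradient output shaped like the failing case: (grads, None).
grads = {"loc": 0.1, "scale": -0.2}
grad_output = (grads, None)

# The tuple wrapper changes the tree structure even though the leaves match.
mismatched = tree_structure(grad_output)

# Taking the first element recovers a tree that matches the parameters.
matched = tree_structure(grad_output[0])

print(mismatched)  # the ({...}, None) structure from the error message
print(matched == param_tree)
```

If that assumption holds, unpacking the gradient (or the `(grads, aux)` pair) before calling `update` would make the two structures agree.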
Thank you for your guidance, @fehiepsi 🙏🙂
Closes https://github.com/pyro-ppl/numpyro/issues/1726
Trying a "good first issue" here.