pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Support forward mode differentiation for SVI #1731

Closed: juanitorduz closed this 8 months ago

juanitorduz commented 8 months ago

Closes https://github.com/pyro-ppl/numpyro/issues/1726

Trying my hand at a "good first issue" here.

fehiepsi commented 8 months ago

Looks great to me. Could you add a simple test with a while loop in the model?
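A while loop makes a good test case because JAX cannot reverse-differentiate through `jax.lax.while_loop`, while forward mode (`jax.jvp`, `jax.jacfwd`) handles it. The sketch below uses plain JAX, not NumPyro, and `loop_loss` is a hypothetical function chosen only to illustrate the difference:

```python
import jax


def loop_loss(x):
    # Hypothetical loss: repeatedly multiplies an accumulator by x inside a
    # lax.while_loop. With the counter running 1 -> 3 this computes x ** 3.
    def cond(state):
        i, _ = state
        return i < 3

    def body(state):
        i, acc = state
        return i + 1, acc * x

    _, acc = jax.lax.while_loop(cond, body, (1, x))
    return acc


# Forward mode pushes tangents through the loop, so this works.
fwd_grad = jax.jacfwd(loop_loss)(2.0)
print(fwd_grad)  # d/dx x**3 = 3 * x**2, i.e. 12.0 at x = 2

# Reverse mode must transpose the loop, which lax.while_loop does not support.
reverse_failed = False
try:
    jax.grad(loop_loss)(2.0)
except ValueError as err:
    reverse_failed = True
    print("reverse mode failed:", type(err).__name__)
```

A model containing such a loop can therefore only be fit with SVI if gradients are taken in forward mode, which is what the test is meant to exercise.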

juanitorduz commented 8 months ago

I added a simple test, test/infer/test_svi.py::test_forward_mode_differentiation, in https://github.com/pyro-ppl/numpyro/pull/1731/commits/532fb5b186d445fda626c3eb682a3d8b666dfb21, but it fails with:

    @functools.wraps(update)
    def tree_update(i, grad_tree, opt_state):
      states_flat, tree, subtrees = opt_state
      grad_flat, tree2 = tree_flatten(grad_tree)
      if tree2 != tree:
        msg = ("optimizer update function was passed a gradient tree that did "
               "not match the parameter tree structure with which it was "
               "initialized: parameter tree {} and grad tree {}.")
>       raise TypeError(msg.format(tree, tree2))
E       TypeError: optimizer update function was passed a gradient tree that did not match the parameter tree structure with which it was initialized: parameter tree PyTreeDef({'loc': *, 'scale': *}) and grad tree PyTreeDef(({'loc': *, 'scale': *}, None))

I am not sure whether it's because my implementation of the test is wrong 😑. Any tips? Thanks!
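The grad tree in the error, `({'loc': *, 'scale': *}, None)`, has the shape `jax.jacfwd` produces when it differentiates a function returning a `(loss, aux)` pair without `has_aux=True`: it differentiates the whole output, so the gradient gains an extra tuple level with `None` for the auxiliary part. The sketch below reproduces that mismatch under this assumption and shows one way to get a forward-mode counterpart of `jax.value_and_grad(..., has_aux=True)`. `loss_with_aux`, `params`, and `value_and_grad_fwd` are hypothetical stand-ins, not NumPyro internals.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_structure

params = {"loc": jnp.array(0.5), "scale": jnp.array(1.5)}


def loss_with_aux(p):
    # Hypothetical stand-in for an ELBO loss: returns (scalar loss, aux).
    loss = (p["loc"] - 1.0) ** 2 + jnp.log(p["scale"]) ** 2
    return loss, None


# Without has_aux, the gradient tree is (params-shaped dict, None),
# which no longer matches the optimizer's parameter tree.
bad_grads = jax.jacfwd(loss_with_aux)(params)

# With has_aux=True, only the loss is differentiated and aux is passed through.
good_grads, aux = jax.jacfwd(loss_with_aux, has_aux=True)(params)


def value_and_grad_fwd(f):
    # Hypothetical helper: forward-mode analogue of
    # jax.value_and_grad(f, has_aux=True), returning ((loss, aux), grads).
    def wrapper(p):
        loss, aux = f(p)
        # Differentiate the loss; carry (loss, aux) out as auxiliary data so
        # the value is available without a second evaluation.
        return loss, (loss, aux)

    def value_and_grad_fn(p):
        grads, (loss, aux) = jax.jacfwd(wrapper, has_aux=True)(p)
        return (loss, aux), grads

    return value_and_grad_fn


(fwd_loss, _), fwd_grads = value_and_grad_fwd(loss_with_aux)(params)
(rev_loss, _), rev_grads = jax.value_and_grad(loss_with_aux, has_aux=True)(params)
print(tree_structure(bad_grads))
print(tree_structure(fwd_grads))
```

Returning the same `((loss, aux), grads)` layout in both modes lets an optimizer update step stay unchanged whichever differentiation mode is selected.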

juanitorduz commented 8 months ago

Thank you for your guidance @fehiepsi 🙏🙂