jeremiecoullon / SGMCMCJax

Lightweight library of stochastic gradient MCMC algorithms written in JAX.
https://sgmcmcjax.readthedocs.io/en/latest/index.html
Apache License 2.0

from sgmcmcjax.optimizer import build_adam_optimizer #67

Closed: ElhamAfzali closed this issue 2 years ago

ElhamAfzali commented 2 years ago

Hi, when I import build_adam_optimizer I get the following error: AttributeError: module 'jax.random' has no attribute 'KeyArray'
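One way to narrow this down is to check, without importing sgmcmcjax at all, whether the installed JAX defines jax.random.KeyArray. If it doesn't, some package on the import path expects a different JAX version than the one installed. This is a minimal sketch; module_has_attr is a hypothetical helper, not part of sgmcmcjax or JAX.

```python
import importlib


def module_has_attr(module_name, attr):
    """Return True/False depending on whether `module_name` defines `attr`,
    or None if the module cannot be imported at all."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


# The same attribute lookup that fails in the error message, checked in isolation
print("jax.random.KeyArray present:", module_has_attr("jax.random", "KeyArray"))
```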

jeremiecoullon commented 2 years ago

Hello! Thanks for pointing out this issue.

What's the stack trace for this error? Specifically, which line in the code causes this (in optimizer.py for example)? And what version of sgmcmcjax are you using? Also, do you have some minimal steps to reproduce?

Note that this script has an example of how to use the optimizers. I would recommend using build_optax_optimizer; it uses optax, which is the official optimization library for JAX. In contrast, build_adam_optimizer uses a JAX demo optimizers library (which has since moved to jax.example_libraries, and which technically shouldn't be used :) ).

If importing it doesn't work, a quick fix is to define the build_optax_optimizer function yourself by copying its code into your script (along with the relevant imports). It uses some building blocks from sgmcmcjax as well as optax.

As an aside, I need to do some maintenance on sgmcmcjax: I need to update the import path of the demo optimizers library so that it still works, and update a few other things. Also note that I might have time to look at this properly next week, but otherwise I'll be busy/away for the next 2 weeks.
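A path update like that is often handled with an import fallback. The following is a minimal sketch, not what sgmcmcjax actually does: it tries the current location, jax.example_libraries.optimizers, and then the older jax.experimental.optimizers. The helper name import_demo_optimizers is hypothetical.

```python
import importlib


def import_demo_optimizers():
    """Return JAX's demo `optimizers` module from whichever location this JAX provides."""
    # Newer JAX releases: jax.example_libraries; older releases: jax.experimental
    for path in ("jax.example_libraries.optimizers", "jax.experimental.optimizers"):
        try:
            return importlib.import_module(path)
        except ImportError:
            continue
    raise ImportError("JAX's demo optimizers library was not found under either path")


optimizers = import_demo_optimizers()
# Each demo optimizer is an (init_fun, update_fun, get_params) triple
init_fun, update_fun, get_params = optimizers.adam(1e-3)
```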

I hope this helps, and let me know if you're still blocked!

ElhamAfzali commented 2 years ago

Hi, thanks for answering. The screenshot below shows the full error.

[Screenshot of the full traceback: Screen Shot 2022-09-16 at 10 55 38]

I am using sgmcmcjax version 0.2.9. I am now trying the manual workaround. Thanks.

ElhamAfzali commented 2 years ago

I found the cause: the problem is the version of the chex library. I was using chex==0.1.5, while the version compatible with your package is chex==0.1.2. :)
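To confirm an environment matches the combination reported here (sgmcmcjax 0.2.9 with chex 0.1.2), a small check against the installed package metadata can help. This is a sketch: the JAX version in that environment was not reported in the thread, so it is not pinned, and version_mismatches is a hypothetical helper.

```python
import importlib.metadata as md

# Combination reported in this thread as working (the JAX version was not stated)
REPORTED_WORKING = {"sgmcmcjax": "0.2.9", "chex": "0.1.2"}


def version_mismatches(pins):
    """Map each distribution whose installed version differs from its pin to (pinned, installed)."""
    mismatches = {}
    for dist, pinned in pins.items():
        try:
            installed = md.version(dist)
        except md.PackageNotFoundError:
            installed = None  # distribution is not installed
        if installed != pinned:
            mismatches[dist] = (pinned, installed)
    return mismatches


print(version_mismatches(REPORTED_WORKING))
```

An empty dict means both pins match. Any entry shows the pinned version next to what is actually installed, so a non-matching chex can be replaced with chex==0.1.2.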

jeremiecoullon commented 2 years ago

Thanks for the update about this!!