ElhamAfzali closed this issue 2 years ago.
Hello! Thanks for pointing out this issue.
What's the stack trace for this error? Specifically, which line in the code causes it (in optimizer.py, for example)? And which version of sgmcmcjax are you using?
Also, do you have some minimal steps to reproduce it?
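A minimal reproduction can be as small as the failing import itself. The sketch below assumes that build_adam_optimizer lives in sgmcmcjax.optimizer. The thread mentions optimizer.py, but the exact module path is an assumption. The script prints the installed sgmcmcjax version and captures the full stack trace as text, so both can be pasted into the issue.

```python
import traceback
from importlib import metadata


def sgmcmcjax_version():
    """Installed sgmcmcjax version string, or None if it is not installed."""
    try:
        return metadata.version("sgmcmcjax")
    except metadata.PackageNotFoundError:
        return None


def reproduce():
    """Run the failing import; return its full traceback, or None on success."""
    try:
        # Module path assumed from the mention of optimizer.py in this thread.
        from sgmcmcjax.optimizer import build_adam_optimizer  # noqa: F401
    except Exception:
        return traceback.format_exc()
    return None


print("sgmcmcjax", sgmcmcjax_version())
print(reproduce() or "import succeeded")
```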
Note that this script has an example of how to use the optimizers.
I would recommend using build_optax_optimizer; it uses optax, which is the official optimization library for JAX. In contrast, build_adam_optimizer uses a JAX demo optimizers library (which has since moved to jax.example_libraries, and which technically shouldn't be used :) ).
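For reference, the demo optimizers library lives in jax.example_libraries in newer JAX releases and in jax.experimental in older ones. The sketch below tries the new path first and falls back to the old one. It then runs Adam through the init/update/get_params triple that this library exposes. The quadratic loss is a toy objective chosen for illustration; it is not sgmcmcjax code.

```python
import jax
import jax.numpy as jnp

try:
    # Location of the demo optimizers in newer JAX releases.
    from jax.example_libraries import optimizers
except ImportError:
    # Older JAX releases shipped the same module under jax.experimental.
    from jax.experimental import optimizers


def loss(params):
    # Toy objective (illustrative only): minimum at params == 3.0.
    return jnp.sum((params - 3.0) ** 2)


opt_init, opt_update, get_params = optimizers.adam(step_size=0.1)
opt_state = opt_init(jnp.zeros(2))


@jax.jit
def step(i, opt_state):
    # One Adam update: gradient at the current params, then apply it.
    grads = jax.grad(loss)(get_params(opt_state))
    return opt_update(i, grads, opt_state)


for i in range(1000):
    opt_state = step(i, opt_state)

params = get_params(opt_state)
print(params)
```

The try/except only papers over the move between locations. Because the module is a demo library, optax remains the more durable choice.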
If importing build_optax_optimizer doesn't work, as a quick fix you could build the function manually by copying its code into your script (and importing the relevant dependencies). It uses some building blocks from sgmcmcjax as well as optax.
As an aside, I need to do some maintenance on sgmcmcjax: update the path of the demo optimizers library so that it still works, and update a few other things. Also note that I might have time to look at this properly next week, but apart from that I'll be busy/away for the next 2 weeks.
I hope this helps, and let me know if you're still blocked!
Hi, thanks for answering. The following picture shows the whole error.
I am using sgmcmcjax version 0.2.9. I am trying to use the manual function. Thanks.
I found the reason. The problem is related to the version of the chex library: I was using chex==0.1.5, while the version compatible with your package is chex==0.1.2. :)
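A small check like the one below can confirm whether an environment matches the pin reported in this thread, chex==0.1.2. That version would be installed with pip install chex==0.1.2. The helper names are hypothetical. The comparison is an exact string match, so it only answers "is this exactly the pinned version?" and does not check a compatible range.

```python
from importlib import metadata

# Pin reported in this thread: chex 0.1.2 worked, chex 0.1.5 did not.
KNOWN_GOOD = {"chex": "0.1.2"}


def installed_version(pkg):
    """Return the installed version string of pkg, or None if it is missing."""
    try:
        return metadata.version(pkg)
    except metadata.PackageNotFoundError:
        return None


def check_pins(pins=KNOWN_GOOD):
    """Map each package to (installed_version, exactly_matches_pin)."""
    report = {}
    for pkg, pinned in pins.items():
        found = installed_version(pkg)
        report[pkg] = (found, found == pinned)
    return report


for pkg, (found, ok) in check_pins().items():
    status = "matches pin" if ok else f"differs from pin {KNOWN_GOOD[pkg]}"
    print(f"{pkg}: installed {found}, {status}")
```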
Thanks for the update about this!!
Hi, while importing build_adam_optimizer I get the following error:
AttributeError: module 'jax.random' has no attribute 'KeyArray'
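This AttributeError means the installed JAX does not define jax.random.KeyArray, while code imported along the way refers to it. In this thread, the cause was traced to the chex version. As far as I know, newer JAX releases also deprecated and later removed jax.random.KeyArray in favor of jax.Array. The same error can therefore appear when an older library is paired with a recent JAX. The sketch below is a hedged way to see which attributes the installed JAX actually exposes. The function name is hypothetical.

```python
import importlib


def key_array_status():
    """Report whether the installed JAX exposes jax.random.KeyArray.

    Returns None if JAX is not importable; otherwise a dict with the JAX
    version, whether jax.random.KeyArray exists, and whether jax.Array exists.
    """
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return None
    return {
        "jax_version": jax.__version__,
        "has_random_KeyArray": hasattr(jax.random, "KeyArray"),
        "has_jax_Array": hasattr(jax, "Array"),
    }


print(key_array_status())
```

If has_random_KeyArray is False, the installed JAX and the library referencing KeyArray are out of step. Aligning versions, as with the chex pin above, is the usual fix.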