Open dfm opened 3 years ago
I think numpyro is imported when needed, not at arviz import.
@ahartikainen: Interesting - but there's no issue importing it from my main code!
Can you share exactly how you're installing stuff? This throws a more sane error, and has jax version 0.2.6, which looks like it matches yours.
@ColCarroll: Sure! I go to https://colab.research.google.com, create a new notebook, and then put exactly the code above into it.
Ref: https://colab.research.google.com/drive/1F1NNJbOoQnJ5b4D7yGOun8yCUBVv8CSd?usp=sharing
Wow! Can repro, and can confirm that the same behavior happens with PRNGKey 86074.
More seriously, the bug happens here, when trying to grab the library version to stamp the inference data. This definitely shouldn't raise, and indicates that we're doing something wrong with pkg_resources, and maybe with lazy-loading libraries.
Shorter term, if you import numpyro before arviz, it works fine.
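On the "this definitely shouldn't raise" point: here is a minimal sketch of a version lookup that returns None instead of raising when the package metadata can't be found. It uses the standard-library importlib.metadata rather than pkg_resources, and the helper name safe_library_version is hypothetical; this is not what ArviZ actually does, just one way the stamping step could be made defensive.

```python
import importlib
from importlib import metadata
from typing import Optional


def safe_library_version(dist_name: str, module_name: Optional[str] = None) -> Optional[str]:
    """Return the installed version of a package, or None if it cannot be found.

    `safe_library_version` is a hypothetical helper for illustration only.
    """
    # First ask the installed distribution metadata.
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        pass

    # Fall back to the module's own __version__ attribute, if it has one.
    try:
        module = importlib.import_module(module_name or dist_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)


# Example: never raises, even when the metadata is missing.
print(safe_library_version("numpyro"))
```

With something like this, a missing or inconsistent distribution would leave the version stamp empty rather than crash the conversion.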
I find that changing the order of the imports doesn't fix the issue, but it does change the specific jax it finds: jax 0.2.6 instead of jaxlib 0.1.57+cuda101.
Edit to add screenshot:
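To check which jax-related distributions the running interpreter can actually see, a small diagnostic along these lines might help. It is only a sketch using the standard-library importlib.metadata; the function name installed_versions is made up for illustration, and it reports what the metadata says, which may differ from what pkg_resources has cached in an already-running runtime.

```python
from importlib import metadata
from typing import List, Tuple


def installed_versions(prefix: str) -> List[Tuple[str, str]]:
    """List (name, version) for every visible distribution whose name starts with `prefix`.

    `installed_versions` is a hypothetical helper for illustration only.
    """
    found = []
    for dist in metadata.distributions():
        try:
            name = dist.metadata.get("Name")
            version = dist.metadata.get("Version", "unknown")
        except Exception:
            # Skip distributions with unreadable metadata.
            continue
        if name and name.lower().startswith(prefix.lower()):
            found.append((name, str(version)))
    # Sort by name so the output is stable between runs.
    return sorted(found, key=lambda pair: pair[0].lower())


# Example: show every distribution whose name starts with "jax" (jax, jaxlib, ...).
for name, version in installed_versions("jax"):
    print(name, version)
```

Running this both before and after a runtime restart would show whether the set of visible distributions changes between the failing and working cases.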
Wow, here's something terrible, then. Restarting then running again fixes it, with the imports in either order.
https://github.com/rasbt/watermark/issues/15
Apparently installing Jupyter fixes this --> colab is not Jupyter.
Edit: Probably something Colab-specific.
@ColCarroll: I find that the truth of your statement is somewhat stochastic :D, but yes, sometimes restarting and running again works (but not always)! Edited to add: In my experience this always fails after a "Factory reset" of the runtime.
@ahartikainen: Agreed - I haven't tried this on other platforms. I do think it's worth trying to figure out what's going on though because colab is a pretty good place to test all these GPU-based MCMC samplers and it's unfortunate if trying to use arviz crashes everything!
Is this still happening?
Describe the bug
This is the strangest thing and I haven't been able to entirely work out what's going on, but perhaps y'all have thoughts!
When using Google Colab, I install numpyro with --no-deps so that it doesn't try to update JAX to the CPU-only version. Then, everything runs fine (the sampler runs and I get my trace) until I execute az.from_numpyro(), at which point it blows up with the following exception:
This doesn't really seem to be an ArviZ issue, but everything is fine until I try to use ArviZ, so perhaps you have thoughts.
To Reproduce
In a factory reset Colab environment I run the following and get the above exception on the last line:
Expected behavior
I would expect this exception to be thrown earlier if it is actually a problem, or for ArviZ to run as expected!