rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

non-hashable static_argnums are not supported in Jax 0.2.6 #58

Closed: jeremiecoullon closed this issue 3 years ago

jeremiecoullon commented 3 years ago

JAX 0.2.6 drops support for non-hashable static arguments to jitted functions. See the JAX changelog, which links to an explanation of the change.

This means that arrays, which are not hashable, can no longer be passed as static arguments to jitted functions. mcx currently does this in at least one place: in inference.warmup.num_steps_adaptation.py, the run() function defined inside longest_batch_before_turn() marks array arguments as static.
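A minimal sketch of the failure mode, assuming a recent JAX install. The function `scale` is a hypothetical stand-in, not the mcx code: it marks an array argument as static, which JAX 0.2.6 and later reject because the array cannot be hashed.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper (not from mcx) that marks its array argument as static.
def scale(x, factors):
    return x * factors.sum()

scale_jit = jax.jit(scale, static_argnums=(1,))

factors = jnp.array([1.0, 2.0, 3.0])

# JAX arrays compare elementwise and are not hashable.
try:
    hash(factors)
    hashable = True
except TypeError:
    hashable = False

# Static arguments are hashed to look up the compilation cache, so passing
# an array in a static position raises an error at call time.
try:
    scale_jit(1.0, factors)
    static_call_failed = False
except (TypeError, ValueError) as err:
    static_call_failed = True
    print(type(err).__name__, err)
```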

As a result, cloning mcx and running it with an up-to-date version of JAX fails.

rlouf commented 3 years ago

Good catch! There's another incompatibility with newer versions of JAX (#36), so I pinned the version temporarily. That's the price to pay when you use a young library :)

Good news is these are easily fixable and won't impact performance 😅
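A sketch of two common ways to make this kind of code compatible. This is an illustration only, not necessarily the change made in mcx, and the function names are hypothetical. The usual fix is to pass the array as an ordinary traced argument. If a value really must be static (for example because it drives Python-level control flow), it can be converted to a hashable tuple instead.

```python
import jax
import jax.numpy as jnp

# Fix 1: the array is an ordinary (traced) argument, so nothing needs hashing
# and new values of the same shape reuse the compiled function.
@jax.jit
def scale_traced(x, factors):
    return x * factors.sum()

# Fix 2: when the value must stay static, pass a hashable Python tuple.
# The loop runs in Python at trace time; each distinct tuple triggers a recompile.
def scale_static(x, factors):
    total = 0.0
    for f in factors:
        total += f
    return x * total

scale_static_jit = jax.jit(scale_static, static_argnums=(1,))

factors = jnp.array([1.0, 2.0, 3.0])
static_factors = tuple(float(v) for v in factors)

print(scale_traced(1.0, factors))                # 6.0
print(scale_static_jit(1.0, static_factors))     # 6.0
```

The traced version is usually preferable: the static tuple variant recompiles whenever the tuple's contents change.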