blackjax-devs / blackjax

BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
https://blackjax-devs.github.io/blackjax/
Apache License 2.0
806 stars, 105 forks

Blackjax has an implicit dependence on "jax>=0.4.25" and "jaxlib>=0.4.25" #665

Closed: ColCarroll closed this 4 months ago

ColCarroll commented 4 months ago

Describe the issue as clearly as possible:

Noticed that `jax.tree.<whatever>` (e.g. `jax.tree.map`) is now being used here, but the `jax.tree` module was only introduced in jax/jaxlib 0.4.25, so BlackJAX breaks on older versions.
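For context, here is a minimal sketch (not BlackJAX's code) of how a library could support JAX on both sides of 0.4.25. It prefers `jax.tree.map` when the `jax.tree` namespace exists and otherwise falls back to the older `jax.tree_util.tree_map`. The `tree_map` alias is a hypothetical name chosen for illustration.

```python
import jax

# `jax.tree` (jax.tree.map, jax.tree.leaves, ...) first appeared in jax 0.4.25.
# On older releases, accessing it raises AttributeError, which hasattr() catches.
if hasattr(jax, "tree"):
    tree_map = jax.tree.map
else:
    # Long-standing spelling, still available in current JAX releases.
    tree_map = jax.tree_util.tree_map

# Either spelling applies a function leaf-by-leaf over a pytree.
example = {"a": 1, "b": (2, 3)}
print(tree_map(lambda leaf: leaf + 1, example))  # {'a': 2, 'b': (3, 4)}
```

The alternative, which is what this issue proposes, is simply to require jax>=0.4.25 and use `jax.tree` unconditionally.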

Steps/code to reproduce the bug:

Install an older version of JAX (here 0.4.24), as in

https://colab.research.google.com/drive/1K5-nH6NXJY-KggtSzbt0_kudUfR0d_2V?usp=sharing

Then any code that reaches a `jax.tree.*` call raises an `AttributeError`, for example:

```python
import jax
import blackjax

blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)
```
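The failure does not depend on BlackJAX itself. The following sketch isolates it by calling `jax.tree.map` directly. `jax_tree_available` is a hypothetical helper. On jax/jaxlib 0.4.24 it should print the same `AttributeError` message as the traceback and return `False`.

```python
import jax


def jax_tree_available() -> bool:
    """Return True if this JAX install exposes the `jax.tree` namespace."""
    try:
        jax.tree.map(lambda leaf: leaf, 1.0)
    except AttributeError as err:
        # On jax/jaxlib 0.4.24: "module 'jax' has no attribute 'tree'".
        print(f"jax {jax.__version__}: {err}")
        return False
    return True


print(jax_tree_available())
```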

Expected result:

no exception

Error message:

```
AttributeError                            Traceback (most recent call last)
<ipython-input-3-6ef54887e135> in <cell line: 4>()
      2 import blackjax
      3 
----> 4 blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)

/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
    106 
    107     # Run LBFGS optimizer on flat input.
--> 108     last_step_raveled, history_raveled = _minimize_lbfgs(
    109         lambda x: fun(unravel_fn(x)),
    110         x0_raveled,

/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in _minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
    231     )
    232     # Append initial state to history.
--> 233     history = jax.tree.map(
    234         lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
    235         initial_history,

/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py in getattr(name)
     51       warnings.warn(message, DeprecationWarning, stacklevel=2)
     52       return fn
---> 53     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     54 
     55   return getattr

AttributeError: module 'jax' has no attribute 'tree'
```
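The failing line prepends the initial L-BFGS state to the stacked history, one pytree leaf at a time. Below is a sketch of the same operation using `jax.tree_util.tree_map`, which exists both before and after 0.4.25. The dict pytrees, their shapes, and the `prepend_initial` name are hypothetical stand-ins. The traceback cuts off before the second argument and does not show BlackJAX's actual history structure.

```python
import jax
import jax.numpy as jnp


def prepend_initial(initial_history, history):
    """Stack the initial state in front of a history whose leaves have a leading step axis."""
    return jax.tree_util.tree_map(
        # x: one leaf of the initial state; y: the matching leaf of the history.
        lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
        initial_history,
        history,
    )


# Hypothetical pytrees: a 3-dimensional position and a scalar objective value.
initial_history = {"x": jnp.zeros(3), "f": jnp.array(0.0)}
history = {"x": jnp.ones((4, 3)), "f": jnp.ones(4)}

full = prepend_initial(initial_history, history)
print(full["x"].shape, full["f"].shape)  # (5, 3) (5,)
```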

Blackjax/JAX/jaxlib/Python version information:

```
blackjax nightly
jax, jaxlib 0.4.24
```

Context for the issue:

No response

junpenglao commented 4 months ago

I guess we can be explicit and update the dependency requirement?

ColCarroll commented 4 months ago

Thanks -- I'd been trying to figure out when to update bayeux to `jax.tree.foo`, and then realized I probably already have to!
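Making the requirement explicit, as suggested above, would mean declaring `jax>=0.4.25` and `jaxlib>=0.4.25` in the package's dependency metadata. A complementary runtime check could report an outdated install up front instead of failing deep inside an optimizer. This is an assumption for illustration, not something BlackJAX is known to do. The sketch below uses only the standard library. `MIN_VERSION`, `release_tuple`, and `check_min_versions` are hypothetical names.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Hypothetical minimum, matching the bound proposed in this issue.
MIN_VERSION = (0, 4, 25)


def release_tuple(ver: str) -> tuple:
    """Parse the leading numeric release, e.g. '0.4.31.dev20240701' -> (0, 4, 31).

    Pre-release and dev suffixes are ignored, so this is only a coarse check.
    """
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unrecognized version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def check_min_versions(packages=("jax", "jaxlib"), minimum=MIN_VERSION):
    """Return human-readable problems; an empty list means every package is new enough."""
    problems = []
    for pkg in packages:
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            problems.append(f"{pkg} is not installed")
            continue
        if release_tuple(installed) < minimum:
            required = ".".join(map(str, minimum))
            problems.append(f"{pkg}=={installed} is older than the required {required}")
    return problems


print(check_min_versions())
```

Declaring the bound in the package metadata is still the primary fix, because it lets pip resolve a compatible JAX before any code runs.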