Closed. ColCarroll closed this issue 4 months ago.
### Describe the issue as clearly as possible:

Noticed that `jax.tree.<whatever>` is now being used here. The `jax.tree` namespace was only introduced in jax/jaxlib 0.4.25.
### Steps/code to reproduce the bug:

Install an older version of JAX, as in this notebook:
https://colab.research.google.com/drive/1K5-nH6NXJY-KggtSzbt0_kudUfR0d_2V?usp=sharing

Then any code that reaches a `jax.tree...` call will throw. For example:
```python
import jax
import blackjax

blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)
```
### Expected result:

No exception.
### Error message:

```
AttributeError                            Traceback (most recent call last)
<ipython-input-3-6ef54887e135> in <cell line: 4>()
      2 import blackjax
      3 
----> 4 blackjax.optimizers.lbfgs.minimize_lbfgs(lambda x: x*x, -1.)

2 frames
/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
    106 
    107     # Run LBFGS optimizer on flat input.
--> 108     last_step_raveled, history_raveled = _minimize_lbfgs(
    109         lambda x: fun(unravel_fn(x)),
    110         x0_raveled,

/usr/local/lib/python3.10/dist-packages/blackjax/optimizers/lbfgs.py in _minimize_lbfgs(fun, x0, maxiter, maxcor, gtol, ftol, maxls, **lbfgs_kwargs)
    231     )
    232     # Append initial state to history.
--> 233     history = jax.tree.map(
    234         lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
    235         initial_history,

/usr/local/lib/python3.10/dist-packages/jax/_src/deprecations.py in getattr(name)
     51       warnings.warn(message, DeprecationWarning, stacklevel=2)
     52       return fn
---> 53     raise AttributeError(f"module {module!r} has no attribute {name!r}")
     54 
     55   return getattr

AttributeError: module 'jax' has no attribute 'tree'
```
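The failing line in the traceback is a `jax.tree.map(...)` call that prepends the initial state to the optimizer history. Below is a minimal sketch of a version-agnostic equivalent, assuming the long-standing `jax.tree_util.tree_map` (which also exists on jax 0.4.24) is an acceptable fallback. `prepend_initial_state` is a hypothetical helper name, not blackjax's API, and this is not necessarily how the issue was resolved upstream.

```python
import jax
import jax.numpy as jnp

# Assumption: jax.tree.map only exists on jax >= 0.4.25. On older releases,
# jax's module-level __getattr__ raises AttributeError, so hasattr() is False
# and we fall back to the older jax.tree_util.tree_map.
tree_map = jax.tree.map if hasattr(jax, "tree") else jax.tree_util.tree_map


def prepend_initial_state(initial_history, history):
    """Hypothetical helper: stack each leaf of `initial_history` in front of
    the matching leaf of `history` along a new leading axis."""
    return tree_map(
        lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
        initial_history,
        history,
    )


# Usage: a pytree with one leaf; 1 initial state + 3 recorded steps.
initial = {"pos": jnp.zeros(2)}
history = {"pos": jnp.ones((3, 2))}
stacked = prepend_initial_state(initial, history)
```

Both branches accept the same `(f, tree, *rest)` arguments here, so the call site does not need to know which JAX version is installed.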
### Blackjax/JAX/jaxlib/Python version information:

```
blackjax nightly
jax, jaxlib 0.4.24
```
### Context for the issue:

No response
I guess we can be explicit and update the dependency requirement?
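The declarative way to be explicit is a `jax>=0.4.25` constraint in the package metadata. As a complement, here is a hypothetical runtime guard, a sketch only: `check_jax_version`, `release_tuple`, and `MIN_JAX_VERSION` are illustrative names, not part of blackjax, and the simple parser ignores pre-release suffixes such as `rc1` or `.dev` tags.

```python
import re

# Assumption (from this issue): jax.tree.* first shipped in jax/jaxlib 0.4.25.
MIN_JAX_VERSION = (0, 4, 25)


def release_tuple(version: str) -> tuple:
    """Leading numeric release segment, e.g. '0.4.25.dev20240201' -> (0, 4, 25).

    Pre-release/dev suffixes are ignored, which is a simplification.
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = tuple(int(p) for p in match.group(0).split("."))
    # Pad short versions such as '0.5' to three components, then truncate.
    return (parts + (0, 0, 0))[:3]


def check_jax_version(installed: str) -> None:
    """Hypothetical guard: raise ImportError if `installed` predates jax.tree."""
    if release_tuple(installed) < MIN_JAX_VERSION:
        wanted = ".".join(map(str, MIN_JAX_VERSION))
        raise ImportError(
            f"jax>={wanted} is required (jax.tree is missing), found jax {installed}"
        )
```

Calling `check_jax_version(jax.__version__)` when the package is imported would replace the opaque `AttributeError: module 'jax' has no attribute 'tree'` with a message that names the required version.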
Thanks! I'd been trying to figure out when to update `bayeux` to `jax.tree.foo`, and then realized I probably already have to!