jax-ml / bayeux

State-of-the-art inference for your Bayesian models.
https://jax-ml.github.io/bayeux/
Apache License 2.0

TypeError: clip() got an unexpected keyword argument 'max' #54

Closed: dirknbr closed this issue 1 month ago

dirknbr commented 1 month ago

This code ran fine a few months ago but now fails. This was in a CPU Colab runtime.

idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0), num_draws=1000, num_chains=2)

/usr/local/lib/python3.10/dist-packages/jax/_src/linear_util.py in call_wrapped(self, *args, **kwargs)
    190 
    191     try:
--> 192       ans = self.f(*args, **dict(self.params, **kwargs))
    193     except:
    194       # Some transformations yield from inside context managers, so we have to

TypeError: clip() got an unexpected keyword argument 'max'

ColCarroll commented 1 month ago

I might need more details on the package versions and the full stack trace, but at a guess it looks like this issue.
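A minimal sketch of one way to collect the version details mentioned in this comment from a Colab session. It assumes the relevant distributions are installed under the PyPI names jax, jaxlib, blackjax and bayeux-ml; the helper name report_versions is hypothetical and not part of bayeux.

```python
import importlib.metadata
import platform


def report_versions(packages=("jax", "jaxlib", "blackjax", "bayeux-ml")):
    """Map each distribution name to its installed version string.

    The distribution names are assumptions; any that are not installed
    map to None instead of raising PackageNotFoundError.
    """
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one "name: version" line per package, suitable for pasting into an issue.
    for name, version in report_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output alongside the full traceback gives a maintainer both pieces of information at once.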

The fix would be to upgrade jax to version 0.4.27 or later. I think pip install -U "jax>=0.4.27" should do it.
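A rough sketch of a guard that fails early with a readable message instead of the clip() TypeError. It assumes, consistent with the diagnosis in this thread, that jax releases before 0.4.27 do not accept the min=/max= keywords on jax.numpy.clip. The names MIN_JAX, parse_release, jax_is_new_enough and check_installed_jax are hypothetical.

```python
import importlib.metadata
import re

# Assumed first jax release whose jnp.clip accepts min=/max= keywords.
MIN_JAX = (0, 4, 27)


def parse_release(version):
    """Turn '0.4.26' or '0.4.27.dev20240401' into a tuple of ints.

    Only the leading numeric components are kept; suffixes such as
    '.dev...' or 'rc1' are ignored, so a pre-release of 0.4.27 is
    treated as 0.4.27 itself.
    """
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def jax_is_new_enough(version, minimum=MIN_JAX):
    """Return True if `version` is at least `minimum`, compared component-wise."""
    return parse_release(version) >= minimum


def check_installed_jax():
    """Raise RuntimeError if the installed jax predates MIN_JAX."""
    installed = importlib.metadata.version("jax")
    if not jax_is_new_enough(installed):
        raise RuntimeError(
            f'jax {installed} is older than 0.4.27; try: pip install -U "jax>=0.4.27"'
        )
```

Calling check_installed_jax() at the top of a notebook, before bx_model.mcmc.blackjax_nuts(...), would surface the version mismatch directly. For stricter handling of pre-release tags, packaging.version.Version is a more complete alternative to the hand-rolled parser.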

dirknbr commented 1 month ago

Thanks, that did fix the error.