lindermanlab / ssm-jax

Bayesian learning and inference for state space models (SSMs) using Google Research's JAX as a backend
MIT License

installation: avoid specifying jax==0.2.21 #20

Closed: murphyk closed this issue 2 years ago

murphyk commented 2 years ago

In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/setup.py#L21, replace jax==0.2.21 with jax>=0.2.21. As it stands, installing the package in Colab uninstalls Colab's default jax 0.2.25.
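To make the difference between the two specifiers concrete, here is a toy, stdlib-only sketch. It is not pip's resolver: it handles only plain X.Y.Z versions and the == and >= operators, while real requirement matching follows PEP 440. An exact pin rejects Colab's 0.2.25, so pip would replace it; a lower bound accepts it.

```python
from __future__ import annotations


def parse_version(v: str) -> tuple[int, ...]:
    """Parse a plain release string like '0.2.25' into (0, 2, 25)."""
    return tuple(int(part) for part in v.split("."))


def satisfies(installed: str, requirement: str) -> bool:
    """Toy check of a single 'name==X' or 'name>=X' requirement (tiny subset of PEP 440)."""
    for op in (">=", "=="):
        if op in requirement:
            _, bound = requirement.split(op)
            if op == ">=":
                return parse_version(installed) >= parse_version(bound)
            return parse_version(installed) == parse_version(bound)
    raise ValueError(f"unsupported requirement: {requirement!r}")


COLAB_JAX = "0.2.25"  # jax version preinstalled in Colab, per the issue

for req in ("jax==0.2.21", "jax>=0.2.21"):
    # An unsatisfied requirement means pip swaps out the preinstalled jax.
    print(req, "->", "keep" if satisfies(COLAB_JAX, req) else "reinstall")
```

With the exact pin, the loop reports "reinstall"; with the lower bound, it reports "keep".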

slinderman commented 2 years ago

@schlagercollin Remind me why we were hardcoding the JAX version? Can we get rid of this constraint now?

schlagercollin commented 2 years ago

Good find. Resolved. I believe we had it pinned for some experimental features, though we don't seem to have a hard dependency on them.

One thing: I'll have to find the replacement for jax.interpreters.xla._xla_callable.cache_clear(), since it seems to have changed with the version upgrade. We were only using it in the timing tests to prevent OOM errors after many JIT compilations; I've just commented it out for now.
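The cache_clear() call follows the same naming convention as Python's functools.lru_cache. The stdlib-only sketch below is an analogy, not JAX internals. compile_for_shape and timing_test are hypothetical stand-ins for a memoized compile step and a timing test. The sketch shows why clearing a per-shape compile cache after a run keeps memory bounded. Newer JAX releases expose a public jax.clear_caches(), but whether that is available depends on the installed version.

```python
from __future__ import annotations

import functools
import time


@functools.lru_cache(maxsize=None)
def compile_for_shape(shape: tuple[int, ...]) -> str:
    """Hypothetical stand-in for an expensive compile step, memoized per input shape."""
    return f"compiled-kernel-{shape}"


def timing_test(shapes: list[tuple[int, ...]]) -> list[float]:
    """Time one 'compile' per shape, then empty the cache so it cannot grow without bound."""
    timings = []
    for shape in shapes:
        start = time.perf_counter()
        compile_for_shape(shape)  # each new shape adds one cache entry
        timings.append(time.perf_counter() - start)
    # Analogous to the commented-out _xla_callable.cache_clear(): drop all cached entries.
    compile_for_shape.cache_clear()
    return timings


times = timing_test([(10,), (100,), (1000,)])
print(len(times), compile_for_shape.cache_info().currsize)
```

After the run, the cache is empty again, so repeated timing runs over many shapes do not accumulate compiled entries.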