Closed murphyk closed 2 years ago
@schlagercollin Remind me why we were hardcoding the JAX version? Can we get rid of this constraint now?
Good find. Resolved. I believe we had it pinned for some experimental features, though we don't seem to have a hard dependency on them.
One thing: I'll have to find the replacement for jax.interpreters.xla._xla_callable.cache_clear(), since it seems that's changed since the version upgrade. We were only using this in the timing tests to prevent OOM errors after lots of JIT compilations; I've just commented it out for now.
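As a possible stopgap for the timing tests, here is a minimal sketch of a version-tolerant cache clear. It assumes that newer JAX releases expose a public `jax.clear_caches()` function (check the installed version before relying on it) and falls back to the private hook mentioned above when that is still present. The helper name `clear_jit_caches` is hypothetical and not part of this repo or of JAX.

```python
def clear_jit_caches(jax_module):
    """Try to drop JAX's compiled-function caches.

    Returns the name of the hook that was called, or None if no known hook
    was found on the given module object.
    """
    # Assumption: newer JAX versions provide a public jax.clear_caches().
    public = getattr(jax_module, "clear_caches", None)
    if callable(public):
        public()
        return "clear_caches"

    # Older private hook referenced in this thread; may be absent after upgrading.
    interpreters = getattr(jax_module, "interpreters", None)
    xla = getattr(interpreters, "xla", None)
    legacy = getattr(xla, "_xla_callable", None)
    cache_clear = getattr(legacy, "cache_clear", None)
    if callable(cache_clear):
        cache_clear()
        return "_xla_callable.cache_clear"

    # Nothing matched: the caller can decide whether to skip or warn.
    return None
```

In the timing tests this would be called as `clear_jit_caches(jax)` after `import jax`; a `None` return means neither hook exists in the installed version, so the OOM mitigation is simply skipped.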
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/setup.py#L21, replace jax==0.2.21 with jax>=0.2.21. With the exact pin, installing the package in Colab uninstalls the default jax 0.2.25.
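To illustrate why the exact pin forces a reinstall, here is a simplified sketch of how the two specifiers treat Colab's preinstalled version. It is not pip's actual resolver and handles only plain numeric `==` and `>=` specifiers (not full PEP 440); the helper names `parse_version` and `satisfies` are hypothetical.

```python
def parse_version(text):
    """Turn a plain numeric version like '0.2.21' into a comparable tuple."""
    return tuple(int(part) for part in text.split("."))


def satisfies(installed, spec):
    """Check an installed version against an '==X' or '>=X' specifier only."""
    if spec.startswith("=="):
        return parse_version(installed) == parse_version(spec[2:])
    if spec.startswith(">="):
        return parse_version(installed) >= parse_version(spec[2:])
    raise ValueError(f"unsupported specifier: {spec}")


colab_default = "0.2.25"
print(satisfies(colab_default, "==0.2.21"))  # False: the exact pin rejects 0.2.25
print(satisfies(colab_default, ">=0.2.21"))  # True: the lower bound accepts it
```

Because 0.2.25 does not satisfy `==0.2.21`, the installer has to replace it, whereas the lower bound leaves the existing install in place.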