dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

Module jax has no attribute typing #42

Closed: bramn22 closed this issue 1 year ago

bramn22 commented 1 year ago

Hi,

I installed keypoint-moseq on Linux using pip, but after importing it with "import keypoint_moseq", I get the following error: "AttributeError: module 'jax' has no attribute 'typing'". I am using jax 0.3.22 (GPU version).

...
File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/jax_moseq/utils/distributions.py:3
      1 import jax, jax.numpy as jnp, jax.random as jr
      2 import tensorflow_probability.substrates.jax.distributions as tfd
----> 3 from dynamax.hidden_markov_model.inference import hmm_posterior_sample
      4 from jax_moseq.utils import nan_check
      5 na = jnp.newaxis

File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/dynamax/hidden_markov_model/__init__.py:1
----> 1 from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions, HMMInitialState, HMMTransitions, HMMParameterSet, HMMPropertySet
      2 from dynamax.hidden_markov_model.models.arhmm import LinearAutoregressiveHMM
      3 from dynamax.hidden_markov_model.models.bernoulli_hmm import BernoulliHMM

File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/dynamax/hidden_markov_model/models/abstractions.py:2
      1 from abc import abstractmethod, ABC
----> 2 from dynamax.ssm import SSM
      3 from dynamax.types import Scalar
      4 from dynamax.parameters import to_unconstrained, from_unconstrained

File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/dynamax/ssm.py:9
      7 from jax import jit, lax, vmap
      8 from jax.tree_util import tree_map
----> 9 from jaxtyping import Float, Array, PyTree
     10 import optax
     11 from tensorflow_probability.substrates.jax import distributions as tfd

File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/jaxtyping/__init__.py:33
     30 del jax
     32 # First import some things as normal
---> 33 from .array_types import (
     34     AbstractArray as AbstractArray,
     35     AbstractDtype as AbstractDtype,
     36     get_array_name_format as get_array_name_format,
     37     set_array_name_format as set_array_name_format,
     38 )
     39 from .decorator import jaxtyped as jaxtyped
     40 from .import_hook import install_import_hook as install_import_hook

File ~/.conda/envs/keypoint_moseq/lib/python3.9/site-packages/jaxtyping/array_types.py:667
    665 PRNGKeyArray = Key[jax.Array, "2"]
    666 Scalar = Shaped[jax.Array, ""]
--> 667 ScalarLike = Shaped[jax.typing.ArrayLike, ""]

AttributeError: module 'jax' has no attribute 'typing'

Thank you, Bram
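The traceback bottoms out in jaxtyping's array_types.py, which evaluates jax.typing.ArrayLike at import time. The following is a minimal sketch, using only the standard library, that reproduces that attribute lookup and reports the relevant package versions before "import keypoint_moseq" is attempted. The helper names installed_version and jax_exposes_typing are hypothetical and are not part of keypoint-moseq, jax, or jaxtyping.

```python
import importlib
import importlib.util
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jax_exposes_typing() -> bool:
    """Mirror the lookup in jaxtyping's array_types.py: after `import jax`,
    is `jax.typing` available as an attribute? (hypothetical helper)"""
    if importlib.util.find_spec("jax") is None:
        return False  # jax is not installed at all
    try:
        jax = importlib.import_module("jax")
    except ImportError:
        return False  # jax is installed but fails to import
    return hasattr(jax, "typing")


if __name__ == "__main__":
    # Print the versions involved in the incompatibility, then the check itself.
    for name in ("jax", "jaxlib", "jaxtyping"):
        print(f"{name}: {installed_version(name)}")
    print("jax.typing available:", jax_exposes_typing())
```

In an environment like the one reported here (jax 0.3.22 plus the new jaxtyping release), the last line would be expected to print False, consistent with the AttributeError above.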

calebweinreb commented 1 year ago

Thanks for the heads-up! It looks like jaxtyping released a new version last night that accesses jax.typing, which is not available in older versions of jax such as 0.3.22. You can fix the issue by downgrading jaxtyping:

pip install -U jaxtyping==0.2.14

We'll pin the jaxtyping version in the next keypoint-moseq release (https://github.com/dattalab/keypoint-moseq/commit/a00e3ce867cf6f4bf49d44ff657fcdda378d8176).
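Until a release with the pin is available, one way to confirm that the downgrade took effect is to compare the installed package metadata against the version suggested in the comment. The following is a minimal sketch. describe_jaxtyping_pin and installed_jaxtyping are hypothetical helpers. The check requires an exact match with 0.2.14 because that is the only version the comment vouches for.

```python
from importlib import metadata
from typing import Optional

PINNED_JAXTYPING = "0.2.14"  # version suggested in the maintainer's comment


def describe_jaxtyping_pin(found: Optional[str], expected: str = PINNED_JAXTYPING) -> str:
    """Summarize how an installed jaxtyping version compares to the pin (hypothetical helper)."""
    if found is None:
        return "jaxtyping is not installed"
    if found == expected:
        return f"jaxtyping {found} matches the pin"
    return f"jaxtyping {found} is installed; run: pip install -U jaxtyping=={expected}"


def installed_jaxtyping() -> Optional[str]:
    """Return the installed jaxtyping version string, or None if it is missing."""
    try:
        return metadata.version("jaxtyping")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print(describe_jaxtyping_pin(installed_jaxtyping()))
```

Because this is a plain string comparison, any other version string, including a post-release of 0.2.14, is reported as a mismatch. That strictness is deliberate: it keeps the check from vouching for versions the thread does not mention.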