dattalab / keypoint-moseq

https://keypoint-moseq.readthedocs.io

'TypeError: unsupported operand type(s)' when importing kpms in Jupyter #121

Status: Closed (closed by MiceOnDrugs 6 months ago)

MiceOnDrugs commented 6 months ago

Hello!

When I run a cell with the code

import keypoint_moseq as kpms

in JupyterLab, I get the following error:

TypeError: unsupported operand type(s) for |: 'NoneType' and '_UnionGenericAlias' 

I'm trying to use keypoint-moseq with JupyterLab. My computer has Windows 11, and I'm using WSL2 with Ubuntu as my Linux distro. I downloaded conda version 23.7.4, cloned the repository, then installed the Linux (GPU) conda environment. I registered a globally accessible kernel using the command in the docs (I have tried to import keypoint_moseq on other kernels and encountered the same problem).

I tried importing scikit-learn to see if I would run into similar issues, but I didn't encounter any problems.

Here's the full Error message:

TypeError                                 Traceback (most recent call last)
Cell In[2], line 1
----> 1 import keypoint_moseq as kpms

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/keypoint_moseq/__init__.py:2
      1 # use double-precision by default
----> 2 from jax import config
      4 config.update("jax_enable_x64", True)
      6 # simple warning formatting

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/jax/__init__.py:63
     61 from .core import eval_context as ensure_compile_time_eval
     62 from jax._src.environment_info import print_environment_info as print_environment_info
---> 63 from jax._src.api import (
     64   ad,  # TODO(phawkins): update users to avoid this.
     65   effects_barrier,
     66   block_until_ready,
     67   checkpoint as checkpoint,
     68   checkpoint_policies as checkpoint_policies,
     69   clear_backends as clear_backends,
     70   closure_convert as closure_convert,
     71   curry,  # TODO(phawkins): update users to avoid this.
     72   custom_gradient as custom_gradient,
     73   custom_jvp as custom_jvp,
     74   custom_vjp as custom_vjp,
     75   default_backend as default_backend,
     76   device_count as device_count,
     77   device_get as device_get,
     78   device_put as device_put,
     79   device_put_sharded as device_put_sharded,
     80   device_put_replicated as device_put_replicated,
     81   devices as devices,
     82   disable_jit as disable_jit,
     83   eval_shape as eval_shape,
     84   flatten_fun_nokwargs,  # TODO(phawkins): update users to avoid this.
     85   float0 as float0,
     86   grad as grad,
     87   hessian as hessian,
     88   host_count as host_count,
     89   host_id as host_id,
     90   host_ids as host_ids,
     91   jacobian as jacobian,
     92   jacfwd as jacfwd,
     93   jacrev as jacrev,
     94   jit as jit,
     95   jvp as jvp,
     96   local_device_count as local_device_count,
     97   local_devices as local_devices,
     98   linearize as linearize,
     99   linear_transpose as linear_transpose,
    100   make_jaxpr as make_jaxpr,
    101   named_call as named_call,
    102   named_scope as named_scope,
    103   pmap as pmap,
    104   process_count as process_count,
    105   process_index as process_index,
    106   pure_callback as pure_callback,
    107   pxla,  # TODO(phawkins): update users to avoid this.
    108   remat as remat,
    109   ShapedArray as ShapedArray,
    110   ShapeDtypeStruct as ShapeDtypeStruct,
    111   value_and_grad as value_and_grad,
    112   vjp as vjp,
    113   vmap as vmap,
    114   xla,  # TODO(phawkins): update users to avoid this.
    115   xla_computation as xla_computation,
    116 )
    118 from jax._src.array import (
    119     make_array_from_single_device_arrays as make_array_from_single_device_arrays,
    120     make_array_from_callback as make_array_from_callback,
    121 )
    123 from jax.version import __version__ as __version__

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/jax/_src/api.py:48
     42 from jax.core import eval_jaxpr
     43 from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
     44                            tree_structure, tree_transpose, tree_leaves,
     45                            treedef_is_leaf, treedef_children,
     46                            Partial, PyTreeDef, all_leaves, treedef_tuple)
---> 48 from jax._src import callback as jcb
     49 from jax._src import device_array
     50 from jax._src import dispatch

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/jax/_src/callback.py:26
     24 from jax._src import lib as jaxlib
     25 from jax._src import util
---> 26 from jax._src import dispatch
     27 from jax.interpreters import ad
     28 from jax.interpreters import batching

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/jax/_src/dispatch.py:58
     56 import jax._src.util as util
     57 from jax._src.util import flatten, unflatten
---> 58 from etils import epath
     61 FLAGS = flags.FLAGS
     63 flags.DEFINE_string(
     64     'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
     65     help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
     66          "compiler should be dumped as text files. Optional. If omitted, JAX "
     67          "will not dump IR.")

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/etils/epath/__init__.py:19
     15 """Public API."""
     17 from __future__ import annotations
---> 19 from etils.epath import testing
     20 from etils.epath.abstract_path import Path
     21 from etils.epath.flags import DEFINE_path

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/etils/epath/testing.py:27
     24 from unittest import mock
     26 from etils.epath import backend
---> 27 from etils.epath import gpath
     28 from etils.epath import stat_utils
     29 from etils.epath.typing import PathLike

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/etils/epath/gpath.py:29
     26 import typing
     27 from typing import Any, ClassVar, Iterator, Optional, Type, TypeVar, Union
---> 29 from etils import epy
     30 from etils.epath import abstract_path
     31 from etils.epath import backend as backend_lib

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/etils/epy/__init__.py:22
     19 import sys
     21 from etils.epy.backports import cached_property
---> 22 from etils.epy.binary_import import binary_adhoc
     23 from etils.epy.contextlib import ContextManager
     24 from etils.epy.env_utils import is_notebook

File ~/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages/etils/epy/binary_import.py:55
     49       return True
     50   return False
     53 @contextlib.contextmanager
     54 def binary_adhoc(
---> 55     restrict: None | py_utils.StrOrStrList = None,
     56     verbose: bool = False,
     57     **kwargs: Any,
     58 ) -> Iterator[None]:
     59   yield

TypeError: unsupported operand type(s) for |: 'NoneType' and '_UnionGenericAlias'

Thank you, Forest
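
The last frame of the traceback shows the failure happening while Python evaluates the annotation `None | py_utils.StrOrStrList` in etils/epy/binary_import.py, and every path in the traceback points at a Python 3.9 environment (.../lib/python3.9/site-packages). The `X | Y` union syntax (PEP 604) is only supported at runtime from Python 3.10; on 3.9 neither `None` nor a `typing.Union` alias implements `|`. The minimal sketch below reproduces the mechanism. It assumes `StrOrStrList` is a `typing.Union` alias, which is consistent with the '_UnionGenericAlias' in the error, but the definition here is hypothetical and not copied from etils.

```python
import sys
import typing

# Hypothetical stand-in for etils' `py_utils.StrOrStrList`, assumed to be a
# typing.Union alias (the error message names '_UnionGenericAlias').
StrOrStrList = typing.Union[str, typing.List[str]]


def pep604_union_with_alias_works() -> bool:
    """Evaluate `None | <typing.Union alias>` the way a runtime-evaluated
    annotation would, and report whether it succeeds."""
    try:
        _ = None | StrOrStrList
    except TypeError as err:
        # On Python 3.9 this is the same message as in the traceback:
        # unsupported operand type(s) for |: 'NoneType' and '_UnionGenericAlias'
        print(f"Python {sys.version_info.major}.{sys.version_info.minor}: {err}")
        return False
    return True


if __name__ == "__main__":
    print(pep604_union_with_alias_works())
```

On Python 3.10 or newer the expression builds an `Optional[...]` union and the function returns True; on 3.9 it raises the TypeError and returns False.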

calebweinreb commented 6 months ago

Can you send the output of pip show etils?

MiceOnDrugs commented 6 months ago

Yes, here's the output:

pip show etils

Name: etils
Version: 1.6.0
Summary: Collection of common python utils
Home-page: 
Author: 
Author-email: Conchylicultor <etils@google.com>
License: 
Location: /home/labmice/anaconda3/envs/keypoint_moseq/lib/python3.9/site-packages
Requires: 
Required-by: jax
Note: you may need to restart the kernel to use updated packages.

calebweinreb commented 6 months ago

Hmm, interesting. Try downgrading to etils version 1.5.2?

pip install etils==1.5.2

MiceOnDrugs commented 6 months ago

This appears to have fixed the issue! Thank you!! :)
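
To check whether a kernel is in the same combination that failed here, a notebook cell like the sketch below can read the installed etils version with the standard-library importlib.metadata and compare it with the running Python version. The cutoff is only a heuristic taken from this thread (etils 1.6.0 failed on Python 3.9, while 1.5.2 worked); it is not an official compatibility statement from etils. The helper names are hypothetical, and the parser assumes a plain "X.Y[.Z]" version string.

```python
import sys
from importlib import metadata
from typing import Optional, Tuple


def major_minor(version: str) -> Tuple[int, int]:
    """Parse 'X.Y[.Z...]' into (X, Y); assumes numeric major/minor parts."""
    parts = version.split(".")
    return int(parts[0]), int(parts[1])


def likely_affected(etils_version: str, python_version: Tuple[int, int]) -> bool:
    """Heuristic from this thread: flag etils >= 1.6 together with Python < 3.10."""
    return python_version < (3, 10) and major_minor(etils_version) >= (1, 6)


def installed_etils_version() -> Optional[str]:
    """Return the etils version visible to this interpreter, or None."""
    try:
        return metadata.version("etils")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_etils_version()
    py = (sys.version_info.major, sys.version_info.minor)
    if version is None:
        print("etils is not installed in this environment")
    elif likely_affected(version, py):
        print(f"etils {version} on Python {py[0]}.{py[1]}: try pip install etils==1.5.2")
    else:
        print(f"etils {version} on Python {py[0]}.{py[1]}: not the combination from this thread")
```

Running it inside the kernel that fails (rather than in a separate shell) checks the environment the kernel actually uses; after reinstalling, the kernel needs a restart, as the pip note above says.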