theislab / scvelo

RNA Velocity generalized through dynamical modeling
https://scvelo.org
BSD 3-Clause "New" or "Revised" License

module 'jax.random' has no attribute 'KeyArray' #1186

Closed: spatts14 closed this issue 6 months ago

spatts14 commented 6 months ago

I cannot import scvelo because it is not compatible with the latest jax (0.4.24) and jaxlib (0.4.24). Which versions do I need to be compatible? ...

```python
import scvelo as scv
```

Error output:

```pytb
AttributeError: module 'jax.random' has no attribute 'KeyArray'
```

Versions:

```pytb
Version: 0.3.1
```
christophechu commented 6 months ago

same question

Zethson commented 6 months ago

Install scvi-tools from main. We'll make a new release soon that fixes this.

hvgogogo commented 6 months ago

> Install scvi-tools from main. We'll make a new release soon that fixes this.

Could you explain how to install scvi-tools from main? Thanks so much.

Zethson commented 6 months ago

@hvgogogo

  1. Clone the repository with `git clone` and `cd` into it
  2. Run `pip install -U .`

If this is unclear, please consult your favorite search engine or LLM.
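To confirm that step 2 above installed the cloned checkout rather than a release from PyPI, a small check like the following can help. It is a sketch that assumes pip recorded a PEP 610 `direct_url.json` file for the local install (recent pip versions do); the helper name `install_origin` is hypothetical.

```python
import json
from importlib import metadata
from typing import Optional


def install_origin(dist_name: str) -> Optional[str]:
    """Report where `dist_name` was installed from; None if it is not installed.

    Relies on the PEP 610 `direct_url.json` file that pip records for installs
    from a local directory or VCS URL. Installs from an index such as PyPI
    do not have this file.
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    raw = dist.read_text("direct_url.json")
    if raw is None:
        return "package index (e.g. PyPI)"
    # For `pip install -U .` this should be a file:// URL pointing at the clone.
    return json.loads(raw).get("url", "unknown")


if __name__ == "__main__":
    origin = install_origin("scvi-tools")
    if origin is None:
        print("scvi-tools is not installed in this environment")
    else:
        print("scvi-tools", metadata.version("scvi-tools"), "installed from", origin)
```

If the output says "package index", the environment is still using a released scvi-tools rather than the code from main.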

christophechu commented 6 months ago

@hvgogogo @Zethson just use jax==0.4.19; the error comes from scvi.

hvgogogo commented 6 months ago

> @hvgogogo @Zethson just using jax == 0.4.19 its the error from scvi

Thanks so much, it works. Just a reminder for anyone following: you need to downgrade jaxlib too (`pip install jaxlib==0.4.19`).
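Because the workaround pins both packages, a quick check that jax and jaxlib actually ended up at the same version can save a debugging round. This sketch assumes the goal is an exact match such as 0.4.19/0.4.19; jax itself accepts a range of jaxlib versions (note the `minimum_jaxlib_version` check in the traceback further down), so a mismatch is a hint rather than proof of a broken install. The helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of `dist_name`, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def pins_match(jax_version: Optional[str], jaxlib_version: Optional[str]) -> bool:
    """True only when both packages are installed at exactly the same version."""
    return jax_version is not None and jax_version == jaxlib_version


if __name__ == "__main__":
    jax_v = installed_version("jax")
    jaxlib_v = installed_version("jaxlib")
    print(f"jax={jax_v} jaxlib={jaxlib_v} matched={pins_match(jax_v, jaxlib_v)}")
```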

christophechu commented 6 months ago

scvi-tools will be upgraded to 1.1.0, which should fix this problem soon.

spatts14 commented 6 months ago

The workaround works on my desktop; however, I'm trying to run it on the HPC, installed everything as instructed, and am running into errors.

I installed it as follows:

        git clone https://github.com/scverse/scvi-tools.git 
        pip install -U scvi-tools
        pip install jax==0.4.19

I am getting the following error after running import scvelo:


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 1
----> 1 import scvelo

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/__init__.py:5
      2 from anndata import AnnData
      3 from scanpy import read, read_loom
----> 5 from scvelo import datasets, logging
      6 from scvelo import plotting as pl
      7 from scvelo import preprocessing as pp

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/datasets/__init__.py:1
----> 1 from ._datasets import (
      2     bonemarrow,
      3     dentategyrus,
      4     dentategyrus_lamanno,
      5     forebrain,
      6     gastrulation,
      7     gastrulation_e75,
      8     gastrulation_erythroid,
      9     pancreas,
     10     pancreatic_endocrinogenesis,
     11     pbmc68k,
     12     toy_data,
     13 )
     14 from ._simulate import simulation
     16 __all__ = [
     17     "bonemarrow",
     18     "dentategyrus",
   (...)
     28     "toy_data",
     29 ]

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/datasets/_datasets.py:10
      6 import pandas as pd
      8 from scanpy import read
---> 10 from scvelo.core import cleanup
     11 from scvelo.read_load import load
     13 url_datadir = "https://github.com/theislab/scvelo_notebooks/raw/master/"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/core/__init__.py:1
----> 1 from ._anndata import (
      2     clean_obs_names,
      3     cleanup,
      4     get_df,
      5     get_initial_size,
      6     get_modality,
      7     get_size,
      8     make_dense,
      9     make_sparse,
     10     merge,
     11     set_initial_size,
     12     set_modality,
     13     show_proportions,
     14 )
     15 from ._arithmetic import clipped_log, invert, multiply, prod_sum, sum
     16 from ._linear_models import LinearRegression

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/core/_anndata.py:15
     11 from scipy.sparse import csr_matrix, issparse, spmatrix
     13 from anndata import AnnData
---> 15 from scvelo import logging as logg
     16 from ._arithmetic import sum
     17 from ._utils import deprecated_arg_names

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/logging.py:12
      8 from packaging.version import parse
     10 from anndata.logging import get_memory_usage
---> 12 from scvelo import settings
     14 _VERBOSITY_LEVELS_FROM_STRINGS = {"error": 0, "warn": 1, "info": 2, "hint": 3}
     17 def info(*args, **kwargs):

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/settings.py:91
     83 """See set_figure_params.
     84 """
     87 # --------------------------------------------------------------------------------
     88 # Functions
     89 # --------------------------------------------------------------------------------
---> 91 warnings.filterwarnings("ignore", category=cbook.mplDeprecation)
     94 # default matplotlib 2.0 palette slightly modified.
     95 vega_10 = list(map(colors.to_hex, cm.tab10.colors))

AttributeError: module 'matplotlib.cbook' has no attribute 'mplDeprecation'
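The `mplDeprecation` failure above comes from the line in `scvelo/settings.py` that passes `cbook.mplDeprecation` to `warnings.filterwarnings`; the installed matplotlib no longer provides that alias. For code that has to run across matplotlib versions, a tolerant lookup could look like this sketch. It assumes newer matplotlib exposes `MatplotlibDeprecationWarning` at the top level and falls back to the builtin `DeprecationWarning` if neither name exists; the helper name is hypothetical and this is not scvelo's actual fix.

```python
import types
import warnings


def deprecation_warning_category(matplotlib_module, cbook_module):
    """Pick matplotlib's deprecation-warning class across versions.

    Tries the top-level MatplotlibDeprecationWarning first, then the old
    cbook.mplDeprecation alias, then the builtin DeprecationWarning.
    """
    for module, name in (
        (matplotlib_module, "MatplotlibDeprecationWarning"),
        (cbook_module, "mplDeprecation"),
    ):
        category = getattr(module, name, None)
        if category is not None:
            return category
    return DeprecationWarning


if __name__ == "__main__":
    try:
        import matplotlib
        from matplotlib import cbook
    except ImportError:
        matplotlib = cbook = types.SimpleNamespace()
    # Same intent as the failing line in settings.py, without the hard attribute access.
    warnings.filterwarnings(
        "ignore", category=deprecation_warning_category(matplotlib, cbook)
    )
```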

I've tried making a new conda environment and reinstalling as described, but I run into the same errors. I am also getting an error with `import scvi`:

/rds/general/user/sep22/home/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/_settings.py:63: UserWarning: Since v1.0.0, scvi-tools no longer uses a random seed by default. Run `scvi.settings.seed = 0` to reproduce results from previous versions.
  self.seed = seed
/rds/general/user/sep22/home/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/_settings.py:70: UserWarning: Setting `dl_pin_memory_gpu_training` is deprecated in v1.0 and will be removed in v1.1. Please pass in `pin_memory` to the data loaders instead.
  self.dl_pin_memory_gpu_training = (
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 import scvi

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
      9 import lightning.pytorch as pl
     10 import rich
---> 11 from chex import dataclass
     13 try:
     14     import ray

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/pytypes.py:54
     52 Numeric = Union[Array, Scalar]
     53 Shape = jax.core.Shape
---> 54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef
     56 Device = jax.Device

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51   warnings.warn(message, DeprecationWarning, stacklevel=2)
     52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'KeyArray'

I then tried upgrading:

pip install --upgrade matplotlib scvelo

and now I'm getting this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 import scvelo

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/__init__.py:6
      3 from scanpy import read, read_loom
      5 from scvelo import datasets, logging
----> 6 from scvelo import plotting as pl
      7 from scvelo import preprocessing as pp
      8 from scvelo import settings

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/__init__.py:3
      1 from scanpy.plotting import paga_compare, rank_genes_groups
----> 3 from .gridspec import gridspec
      4 from .heatmap import heatmap
      5 from .paga import paga

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/gridspec.py:6
      3 import matplotlib.pyplot as pl
      5 # todo: auto-complete and docs wrapper
----> 6 from .scatter import scatter
      7 from .utils import get_figure_params, hist
      8 from .velocity_embedding import velocity_embedding

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/scatter.py:16
     14 from scvelo.preprocessing.neighbors import get_connectivities
     15 from .docs import doc_params, doc_scatter
---> 16 from .utils import (
     17     default_basis,
     18     default_color,
     19     default_color_map,
     20     default_legend_loc,
     21     default_size,
     22     default_xkey,
     23     default_ykey,
     24     get_ax,
     25     get_components,
     26     get_figure_params,
     27     get_kwargs,
     28     get_obs_vector,
     29     get_value_counts,
     30     gets_vals_from_color_gradients,
     31     groups_to_bool,
     32     interpret_colorkey,
     33     is_categorical,
     34     is_int,
     35     is_list,
     36     is_list_of_int,
     37     is_list_of_list,
     38     is_list_of_str,
     39     make_dense,
     40     plot_density,
     41     plot_linfit,
     42     plot_outline,
     43     plot_polyfit,
     44     plot_rug,
     45     plot_velocity_fits,
     46     rgb_custom_colormap,
     47     savefig_or_show,
     48     set_colorbar,
     49     set_colors_for_categorical_obs,
     50     set_label,
     51     set_legend,
     52     set_margin,
     53     set_title,
     54     to_list,
     55     to_val,
     56     to_valid_bases_list,
     57     update_axes,
     58 )
     61 @doc_params(scatter=doc_scatter)
     62 def scatter(
     63     adata=None,
   (...)
    122     **kwargs,
    123 ):
    124     """Scatter plot along observations or variables axes.
    125 
    126     Arguments:
   (...)
    138     If `show==False` a `matplotlib.Axis`
    139     """

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/utils.py:24
     22 from scvelo import logging as logg
     23 from scvelo import settings
---> 24 from scvelo.tools.utils import strings_to_categoricals
     25 from . import palettes
     27 """helper functions"""

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/tools/__init__.py:14
      4 from ._em_model_core import (
      5     align_dynamics,
      6     differential_kinetic_test,
   (...)
     11     recover_latent_time,
     12 )
     13 from ._steady_state_model import SecondOrderSteadyStateModel, SteadyStateModel
---> 14 from ._vi_model import VELOVI
     15 from .paga import paga
     16 from .rank_velocity_genes import rank_velocity_genes, velocity_clusters

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/tools/_vi_model.py:14
     11 from scipy.stats import ttest_ind
     13 from anndata import AnnData
---> 14 from scvi.data import AnnDataManager
     15 from scvi.data.fields import LayerField
     16 from scvi.dataloaders import DataSplitter

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
      9 import lightning.pytorch as pl
     10 import rich
---> 11 from chex import dataclass
     13 try:
     14     import ray

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/pytypes.py:19
     15 """Type definitions to use for type annotations."""
     17 from typing import Any, Iterable, Mapping, Union
---> 19 import jax
     20 import jax.numpy as jnp
     21 import numpy as np

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/__init__.py:39
     34 del _cloud_tpu_init
     36 # Confusingly there are two things named "config": the module and the class.
     37 # We want the exported object to be the class, so we first import the module
     38 # to make sure a later import doesn't overwrite the class.
---> 39 from jax import config as _config_module
     40 del _config_module
     42 # Force early import, allowing use of `jax.core` after importing `jax`.

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/config.py:17
      1 # Copyright 2018 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 
     15 # TODO(phawkins): fix users of this alias and delete this file.
---> 17 from jax._src.config import config  # noqa: F401

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/config.py:27
     24 import threading
     25 from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
---> 27 from jax._src import lib
     28 from jax._src.lib import jax_jit
     29 from jax._src.lib import transfer_guard_lib

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/lib/__init__.py:75
     70   return _jaxlib_version
     73 version_str = jaxlib.version.__version__
     74 version = check_jaxlib_version(
---> 75   jax_version=jax.version.__version__,
     76   jaxlib_version=jaxlib.version.__version__,
     77   minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
     79 # Before importing any C compiled modules from jaxlib, first import the CPU
     80 # feature guard module to verify that jaxlib was compiled in a way that only
     81 # uses instructions that are present on this machine.
     82 import jaxlib.cpu_feature_guard as cpu_feature_guard

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

Any thoughts?

aterceros commented 6 months ago

Hi, I'm having the same issue. I downgraded to jax==0.4.19 but still get the "module 'jax.random' has no attribute 'KeyArray'" error. Any hints?

christophechu commented 6 months ago

@aterceros @spatts14 upgrading scvi-tools to 1.1.0 may solve this issue.

bio-la commented 6 months ago

Has anyone tested whether this issue with MultiVI is also sorted in scvi-tools 1.1.0? https://discourse.scverse.org/t/error-when-training-model-on-m3-max-mps/1896 MultiVI breaks with 1.0.4 (no issues with scVI or totalVI).

spatts14 commented 6 months ago

Solution:

I updated to scvi-tools 1.1.1 and upgraded pandas.

Current versions:

- scanpy 1.9.8
- scvelo 0.3.1
- scvi-tools 1.1.1

Now running on the HPC.

Of note, I use Parse data. The Parse website says to use pandas==1.5.3; however, I needed to upgrade it.
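To compare an environment against the versions listed above (useful on an HPC, where the active environment may not be the one you expect), a small report like this sketch can be pasted into a notebook. It only reads installed package metadata and does not import the packages, so it still works while `import scvelo` is failing; `report_versions` is a hypothetical helper.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def report_versions(dist_names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version (None if missing)."""
    versions: Dict[str, Optional[str]] = {}
    for name in dist_names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Packages mentioned in this thread; adjust the list as needed.
    names = ["scanpy", "scvelo", "scvi-tools", "pandas", "jax", "jaxlib"]
    for name, version in report_versions(names).items():
        print(f"{name:12s} {version or 'not installed'}")
```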

aterceros commented 6 months ago

Hi, thank you for the comments. I upgraded to scvi-tools 1.1.0post2 (which seems to be the latest; I see no 1.1.1 version), but the error persists. @spatts14, what version of pandas do you have? Thanks a lot!

spatts14 commented 6 months ago

I have pandas==2.2.0, and I also have scvi-tools 1.1.1 (released yesterday). I would suggest setting up a new environment and installing from main.

epignatelli commented 6 months ago

Just landed here by chance while searching for the error.

This might help: https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

`KeyArray` was deprecated and has now been removed; use `jax.Array` instead.
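For code that has to run on jax versions both before and after the removal, one option is to look the name up defensively and fall back to `jax.Array`. This is a sketch, not what chex or scvi-tools actually ship; `prng_key_annotation` is a hypothetical helper, and code that only targets current jax can simply annotate keys with `jax.Array` as the linked document describes.

```python
import types


def prng_key_annotation(jax_module):
    """Return a type suitable for annotating PRNG keys.

    Older jax releases provide jax.random.KeyArray; in releases where it has
    been removed, getattr with a default returns None and we use jax.Array.
    """
    key_array = getattr(jax_module.random, "KeyArray", None)
    return key_array if key_array is not None else jax_module.Array


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("jax is not installed")
    else:
        PRNGKey = prng_key_annotation(jax)
        key = jax.random.PRNGKey(0)  # key data is returned as a jax.Array
        print(PRNGKey, isinstance(key, jax.Array))
```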

dvagbear commented 5 months ago

> Just bumped here by total chance while searching for the error.
>
> This might help: https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

A quick fix is to change `jax.random.KeyArray` to `jax.Array` in `chex/_src/pytypes.py`, as the error hints.

Xxianna commented 5 months ago

The bug still exists in 1.1.2 😭 with jax==0.4.25. Use `PRNGKey = jax.random.PRNGKey` instead of `PRNGKey = jax.random.KeyArray` 😀. It will probably work; according to the jax documentation linked above, I see they have the same parameters. At least my init worked. The change goes in `\site-packages\chex\_src\pytypes.py`.
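The manual edit suggested in the last two comments can be sketched as follows, assuming the file layout shown in the traceback (`chex/_src/pytypes.py`). The script locates the file without importing chex (importing it is exactly what fails) and only reports whether the `KeyArray` reference is still present; writing the file back is left to you, and a reinstall of chex will undo any such edit. It uses the `jax.Array` replacement from the jax documentation rather than `jax.random.PRNGKey`, which is a function rather than a type. Function names are hypothetical.

```python
import importlib.util
import os
import re
from typing import Optional


def chex_pytypes_path() -> Optional[str]:
    """Locate chex/_src/pytypes.py; find_spec on a top-level name does not import it."""
    spec = importlib.util.find_spec("chex")
    if spec is None or spec.origin is None:
        return None
    return os.path.join(os.path.dirname(spec.origin), "_src", "pytypes.py")


def patch_key_array(source: str) -> str:
    """Replace references to the removed jax.random.KeyArray with jax.Array."""
    return re.sub(r"\bjax\.random\.KeyArray\b", "jax.Array", source)


if __name__ == "__main__":
    path = chex_pytypes_path()
    if path is None:
        print("chex is not installed")
    else:
        with open(path, encoding="utf-8") as f:
            text = f.read()
        print(path, "needs patching:", patch_key_array(text) != text)
```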

WeilerP commented 5 months ago

You can now `pip install scvelo` (scvelo>=0.3.2) without scvi-tools and jax as dependencies.
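To check whether an environment already has an scvelo release at or above the 0.3.2 mentioned above, a sketch like this compares the installed version against that threshold. It checks the version only, not whether scvi-tools or jax are still present. The parser is deliberately simple to stay dependency-free; `packaging.version.Version` is the more robust choice when available. Helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional, Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Leading numeric release segment, e.g. '0.3.2' -> (0, 3, 2).

    Suffixes such as 'rc1' or 'post2' are dropped, so '0.3.2rc1'
    compares equal to '0.3.2'.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop at the first segment that carries a suffix
    return tuple(parts)


def scvelo_at_least(minimum: Tuple[int, ...] = (0, 3, 2)) -> Optional[bool]:
    """True if the installed scvelo is at least `minimum`; None if scvelo is missing."""
    try:
        installed = metadata.version("scvelo")
    except metadata.PackageNotFoundError:
        return None
    return release_tuple(installed) >= minimum


if __name__ == "__main__":
    print("scvelo>=0.3.2 installed:", scvelo_at_least())
```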

littlewhitesea commented 4 months ago

I ran into the same problem and solved it with the following command from Stack Overflow:

pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

canergen commented 4 months ago

Just seeing this. The dependency on chex was removed in scVI-tools 1.1.0, so this error is not supposed to happen. If you're still facing it, I would recommend setting up a new environment or otherwise reporting it directly at scVI-tools.