scverse / pertpy

Perturbation Analysis in the scverse ecosystem.
https://pertpy.readthedocs.io/en/latest/
MIT License
92 stars 19 forks source link

ImportError: cannot import name 'linear_util' from 'jax' #545

Closed VladimirShitov closed 4 months ago

VladimirShitov commented 4 months ago

Report

According to this discussion, `linear_util` was moved out of the top-level `jax` namespace in jax 0.4.24, so importing pertpy with that version of jax fails with the error below. A quick workaround is to install an earlier version of jax:

mamba install -y conda-forge::jax=0.4.23

Full traceback:

{
    "name": "ImportError",
    "message": "cannot import name 'linear_util' from 'jax' (/home/icb/vladimir.shitov/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/jax/__init__.py)",
    "stack": "---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb Cell 1 line 6
      <a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> import pandas as pd
      <a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> import seaborn as sns
----> <a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a> import pertpy as pt

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/__init__.py:17
     14 warnings.filterwarnings(\"ignore\", category=UserWarning)
     16 from . import data as dt
---> 17 from . import plot as pl
     18 from . import preprocessing as pp
     19 from . import tools as tl

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/plot/__init__.py:13
     11 from pertpy.plot._milopy import MilopyPlot as milo
     12 from pertpy.plot._mixscape import MixscapePlot as ms
---> 13 from pertpy.plot._scgen import JaxscgenPlot as scg

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/plot/_scgen.py:7
      5 from matplotlib import pyplot
      6 from scipy import stats
----> 7 from scvi import REGISTRY_KEYS
     10 class JaxscgenPlot:
     11     \"\"\"Plotting functions for Jaxscgen.\"\"\"

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import data, model, external, utils
     13 from importlib.metadata import version
     15 package_name = \"scvi-tools\"

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/data/__init__.py:25
      4 from ._datasets import (
      5     annotation_simulation,
      6     brainlarge_dataset,
   (...)
     22     synthetic_iid,
     23 )
     24 from ._manager import AnnDataManager, AnnDataManagerValidationCheck
---> 25 from ._preprocessing import (
     26     add_dna_sequence,
     27     organize_cite_seq_10x,
     28     organize_multiome_anndatas,
     29     poisson_gene_selection,
     30     reads_to_fragments,
     31 )
     32 from ._read import read_10x_atac, read_10x_multiome
     34 __all__ = [
     35     \"AnnTorchDataset\",
     36     \"AnnDataManagerValidationCheck\",
   (...)
     66     \"cellxgene\",
     67 ]

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/data/_preprocessing.py:12
      9 import torch
     10 from scipy.sparse import issparse
---> 12 from scvi.model._utils import parse_device_args
     13 from scvi.utils import dependencies, track
     14 from scvi.utils._docstrings import devices_dsp

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/model/_amortizedlda.py:15
     13 from scvi.data import AnnDataManager
     14 from scvi.data.fields import LayerField
---> 15 from scvi.module import AmortizedLDAPyroModule
     16 from scvi.utils import setup_anndata_dsp
     18 from .base import BaseModelClass, PyroSviTrainMixin

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/_amortizedlda.py:15
     13 from scvi._constants import REGISTRY_KEYS
     14 from scvi._types import Tunable
---> 15 from scvi.module.base import PyroBaseModuleClass, auto_move_data
     16 from scvi.nn import Encoder
     18 _AMORTIZED_LDA_PYRO_MODULE_NAME = \"amortized_lda\"

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
      2     BaseMinifiedModeModuleClass,
      3     BaseModuleClass,
      4     JaxBaseModuleClass,
      5     LossOutput,
      6     PyroBaseModuleClass,
      7     TrainStateWithState,
      8 )
      9 from ._decorators import auto_move_data, flax_configure
     11 __all__ = [
     12     \"BaseModuleClass\",
     13     \"LossOutput\",
   (...)
     19     \"BaseMinifiedModeModuleClass\",
     20 ]

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/base/_base_module.py:8
      5 from dataclasses import field
      6 from typing import Any, Callable
----> 8 import flax
      9 import jax
     10 import jax.numpy as jnp

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/__init__.py:18
      1 # Copyright 2022 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the \"License\");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     16 \"\"\"Flax API.\"\"\"
---> 18 from . import core
     19 from . import jax_utils
     20 from . import linen

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/core/__init__.py:15
      1 # Copyright 2022 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the \"License\");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .axes_scan import broadcast as broadcast
     16 from .frozen_dict import (
     17   FrozenDict as FrozenDict,
     18   freeze as freeze,
     19   unfreeze as unfreeze
     20 )
     22 from .tracers import (
     23   current_trace as current_trace,
     24   trace_level as trace_level,
     25   check_trace_level as check_trace_level
     26 )

File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/core/axes_scan.py:21
     19 import jax
     20 from jax import lax
---> 21 from jax import linear_util as lu
     22 from jax.interpreters import partial_eval as pe
     23 import jax.numpy as jnp

ImportError: cannot import name 'linear_util' from 'jax' (/home/icb/vladimir.shitov/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/jax/__init__.py)"
}

You may want to pin the jax dependency to a compatible version (e.g. `jax<=0.4.23`) until this is fixed.

Version information

No response

Zethson commented 4 months ago

This needs to be fixed upstream in scvi-tools. I'll ping them about it.