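According to this discussion, `linear_util` was removed from the top-level `jax` namespace in jax 0.4.24 (it is now available as `jax.extend.linear_util`). Importing pertpy with this version of jax therefore fails with the ImportError shown below, because the installed flax still runs `from jax import linear_util`. A quick fix is to install the previous jax version: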
mamba install -y conda-forge::jax=0.4.23
Full traceback:
{
"name": "ImportError",
"message": "cannot import name 'linear_util' from 'jax' (/home/icb/vladimir.shitov/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/jax/__init__.py)",
"stack": "---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb Cell 1 line 6
<a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a> import pandas as pd
<a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a> import seaborn as sns
----> <a href='vscode-notebook-cell://localhost:8081/home/icb/vladimir.shitov/projects/2023_12_COPD_Francesca_analysis/03_data_analysis.ipynb#W0sdnNjb2RlLXJlbW90ZQ%3D%3D?line=5'>6</a> import pertpy as pt
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/__init__.py:17
14 warnings.filterwarnings(\"ignore\", category=UserWarning)
16 from . import data as dt
---> 17 from . import plot as pl
18 from . import preprocessing as pp
19 from . import tools as tl
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/plot/__init__.py:13
11 from pertpy.plot._milopy import MilopyPlot as milo
12 from pertpy.plot._mixscape import MixscapePlot as ms
---> 13 from pertpy.plot._scgen import JaxscgenPlot as scg
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/pertpy/plot/_scgen.py:7
5 from matplotlib import pyplot
6 from scipy import stats
----> 7 from scvi import REGISTRY_KEYS
10 class JaxscgenPlot:
11 \"\"\"Plotting functions for Jaxscgen.\"\"\"
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/__init__.py:11
8 from ._settings import settings
10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import data, model, external, utils
13 from importlib.metadata import version
15 package_name = \"scvi-tools\"
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/data/__init__.py:25
4 from ._datasets import (
5 annotation_simulation,
6 brainlarge_dataset,
(...)
22 synthetic_iid,
23 )
24 from ._manager import AnnDataManager, AnnDataManagerValidationCheck
---> 25 from ._preprocessing import (
26 add_dna_sequence,
27 organize_cite_seq_10x,
28 organize_multiome_anndatas,
29 poisson_gene_selection,
30 reads_to_fragments,
31 )
32 from ._read import read_10x_atac, read_10x_multiome
34 __all__ = [
35 \"AnnTorchDataset\",
36 \"AnnDataManagerValidationCheck\",
(...)
66 \"cellxgene\",
67 ]
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/data/_preprocessing.py:12
9 import torch
10 from scipy.sparse import issparse
---> 12 from scvi.model._utils import parse_device_args
13 from scvi.utils import dependencies, track
14 from scvi.utils._docstrings import devices_dsp
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/model/__init__.py:2
1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
3 from ._autozi import AUTOZI
4 from ._condscvi import CondSCVI
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/model/_amortizedlda.py:15
13 from scvi.data import AnnDataManager
14 from scvi.data.fields import LayerField
---> 15 from scvi.module import AmortizedLDAPyroModule
16 from scvi.utils import setup_anndata_dsp
18 from .base import BaseModelClass, PyroSviTrainMixin
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
2 from ._autozivae import AutoZIVAE
3 from ._classifier import Classifier
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/_amortizedlda.py:15
13 from scvi._constants import REGISTRY_KEYS
14 from scvi._types import Tunable
---> 15 from scvi.module.base import PyroBaseModuleClass, auto_move_data
16 from scvi.nn import Encoder
18 _AMORTIZED_LDA_PYRO_MODULE_NAME = \"amortized_lda\"
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
2 BaseMinifiedModeModuleClass,
3 BaseModuleClass,
4 JaxBaseModuleClass,
5 LossOutput,
6 PyroBaseModuleClass,
7 TrainStateWithState,
8 )
9 from ._decorators import auto_move_data, flax_configure
11 __all__ = [
12 \"BaseModuleClass\",
13 \"LossOutput\",
(...)
19 \"BaseMinifiedModeModuleClass\",
20 ]
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/scvi/module/base/_base_module.py:8
5 from dataclasses import field
6 from typing import Any, Callable
----> 8 import flax
9 import jax
10 import jax.numpy as jnp
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/__init__.py:18
1 # Copyright 2022 The Flax Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the \"License\");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
16 \"\"\"Flax API.\"\"\"
---> 18 from . import core
19 from . import jax_utils
20 from . import linen
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/core/__init__.py:15
1 # Copyright 2022 The Flax Authors.
2 #
3 # Licensed under the Apache License, Version 2.0 (the \"License\");
(...)
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
---> 15 from .axes_scan import broadcast as broadcast
16 from .frozen_dict import (
17 FrozenDict as FrozenDict,
18 freeze as freeze,
19 unfreeze as unfreeze
20 )
22 from .tracers import (
23 current_trace as current_trace,
24 trace_level as trace_level,
25 check_trace_level as check_trace_level
26 )
File ~/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/flax/core/axes_scan.py:21
19 import jax
20 from jax import lax
---> 21 from jax import linear_util as lu
22 from jax.interpreters import partial_eval as pe
23 import jax.numpy as jnp
ImportError: cannot import name 'linear_util' from 'jax' (/home/icb/vladimir.shitov/software/miniconda3/envs/2024_01_COPD_analysis/lib/python3.10/site-packages/jax/__init__.py)"
}
You may want to pin the jax dependency to a specific version (e.g. `jax<0.4.24`) until the downstream packages support the newer release.
Report
According to this discussion, `linear_util` was moved in jax 0.4.24. So trying to import pertpy with this version of jax causes the error. A quick fix is to install a previous version:
mamba install -y conda-forge::jax=0.4.23
Full traceback: see the ImportError traceback above.
You may want to pin the jax dependency to a specific version (e.g. `jax<0.4.24`).
Version information
No response