theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License
323 stars 50 forks source link

No module named 'jax.extend' #229

Closed dustin-mullaney closed 4 months ago

dustin-mullaney commented 4 months ago

Previously I was using scarches without issue on a CPU. I wanted to train a model on a GPU, which required installing jaxlib. After that, everything stopped working and I was unable to import scarches. I then set up the environment provided on the install page using:

git clone https://github.com/theislab/scarches
cd scarches
conda env create -f envs/scarches_linux.yaml
conda activate scarches

I am still getting the following error when I import scarches:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 3
      1 import scanpy as sc
      2 import torch
----> 3 import scarches as sca
      4 from scarches.dataset.trvae.data_handling import remove_sparsity
      5 import matplotlib.pyplot as plt

File ~/.local/lib/python3.10/site-packages/scarches/__init__.py:1
----> 1 from . import dataset, metrics, trainers, models, zenodo, plotting, utils, classifiers
      3 __author__ = ', '.join([
      4     'Mohammad Lotfollahi',
      5     'Sergei Rybakov',
      6     'Marco Wagenstetter',
      7     'Mohsen Naghipourfar',
      8 ])
     10 __email__ = ', '.join([
     11     'mohammad.lotfollahi@helmholtz-muenchen.de',
     12     'sergei.rybakov@helmholtz-muenchen.de'
     13 ])

File ~/.local/lib/python3.10/site-packages/scarches/models/__init__.py:1
----> 1 from .trvae.trvae import trVAE
      2 from .trvae.trvae_model import TRVAE
      3 from .trvae.adaptors import Adaptor, attach_adaptors

File ~/.local/lib/python3.10/site-packages/scarches/models/trvae/trvae.py:9
      6 import torch.nn.functional as F
      8 from .modules import Encoder, Decoder
----> 9 from .losses import mse, mmd, zinb, nb
     10 from ._utils import one_hot_encoder
     11 from ..base._base import CVAELatentsModelMixin

File ~/.local/lib/python3.10/site-packages/scarches/models/trvae/losses.py:1
----> 1 from scvi.distributions import NegativeBinomial
      2 import torch
      3 from torch.autograd import Variable

File ~/.local/lib/python3.10/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import data, model, external, utils
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/.local/lib/python3.10/site-packages/scvi/data/__init__.py:25
      4 from ._datasets import (
      5     annotation_simulation,
      6     brainlarge_dataset,
   (...)
     22     synthetic_iid,
     23 )
     24 from ._manager import AnnDataManager, AnnDataManagerValidationCheck
---> 25 from ._preprocessing import (
     26     add_dna_sequence,
     27     organize_cite_seq_10x,
     28     organize_multiome_anndatas,
     29     poisson_gene_selection,
     30     reads_to_fragments,
     31 )
     32 from ._read import read_10x_atac, read_10x_multiome
     34 __all__ = [
     35     "AnnTorchDataset",
     36     "AnnDataManagerValidationCheck",
   (...)
     66     "cellxgene",
     67 ]

File ~/.local/lib/python3.10/site-packages/scvi/data/_preprocessing.py:12
      9 import torch
     10 from scipy.sparse import issparse
---> 12 from scvi.model._utils import parse_device_args
     13 from scvi.utils import dependencies, track
     14 from scvi.utils._docstrings import devices_dsp

File ~/.local/lib/python3.10/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/.local/lib/python3.10/site-packages/scvi/model/_amortizedlda.py:15
     13 from scvi.data import AnnDataManager
     14 from scvi.data.fields import LayerField
---> 15 from scvi.module import AmortizedLDAPyroModule
     16 from scvi.utils import setup_anndata_dsp
     18 from .base import BaseModelClass, PyroSviTrainMixin

File ~/.local/lib/python3.10/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier

File ~/.local/lib/python3.10/site-packages/scvi/module/_amortizedlda.py:15
     13 from scvi._constants import REGISTRY_KEYS
     14 from scvi._types import Tunable
---> 15 from scvi.module.base import PyroBaseModuleClass, auto_move_data
     16 from scvi.nn import Encoder
     18 _AMORTIZED_LDA_PYRO_MODULE_NAME = "amortized_lda"

File ~/.local/lib/python3.10/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
      2     BaseMinifiedModeModuleClass,
      3     BaseModuleClass,
      4     JaxBaseModuleClass,
      5     LossOutput,
      6     PyroBaseModuleClass,
      7     TrainStateWithState,
      8 )
      9 from ._decorators import auto_move_data, flax_configure
     11 __all__ = [
     12     "BaseModuleClass",
     13     "LossOutput",
   (...)
     19     "BaseMinifiedModeModuleClass",
     20 ]

File ~/.local/lib/python3.10/site-packages/scvi/module/base/_base_module.py:8
      5 from dataclasses import field
      6 from typing import Any, Callable
----> 8 import flax
      9 import jax
     10 import jax.numpy as jnp

File ~/.local/lib/python3.10/site-packages/flax/__init__.py:24
     21 config: configurations.Config = configurations.config
     22 del configurations
---> 24 from flax import core
     25 from flax import jax_utils
     26 from flax import linen

File ~/.local/lib/python3.10/site-packages/flax/core/__init__.py:15
      1 # Copyright 2024 The Flax Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
---> 15 from .axes_scan import broadcast as broadcast
     16 from .frozen_dict import (
     17     FrozenDict as FrozenDict,
     18     copy as copy,
   (...)
     22     unfreeze as unfreeze,
     23 )
     24 from .lift import (
     25     custom_vjp as custom_vjp,
     26     jit as jit,
   (...)
     33     while_loop as while_loop,
     34 )

File ~/.local/lib/python3.10/site-packages/flax/core/axes_scan.py:23
     21 import numpy as np
     22 from jax import core, lax
---> 23 from jax.extend import linear_util as lu
     24 from jax.interpreters import partial_eval as pe
     26 ScanAxis = Optional[int]

ModuleNotFoundError: No module named 'jax.extend'
dustin-mullaney commented 4 months ago

This turned out to be an issue with how I installed the conda environment on my institution's HPC system. In case it is ever useful for anyone, registering the environment with Jupyter as an IPython kernel helped me resolve this. Run the following from within the activated `scarches` environment:

python -m ipykernel install --name=scarches