theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

environment reproducibility #219

Closed: racng closed this issue 7 months ago

racng commented 8 months ago

Could you share the versions of the essential packages installed in your dev environment?

I tried installing all requirements using the YAML file provided in this repo; however, it installs CPU versions of torchvision and torchaudio, which prevents us from using our GPU. If I follow the PyTorch instructions to install PyTorch with CUDA 11.7 (conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia), I get CUDA versions of torchvision and torchaudio, but I run into jax errors when trying to import scArches:

import scarches

INFO:lightning_fabric.utilities.seed:Global seed set to 0
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scarches/__init__.py", line 1, in <module>
    from . import dataset, metrics, trainers, models, zenodo, plotting, utils, classifiers
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scarches/models/__init__.py", line 1, in <module>
    from .trvae.trvae import trVAE
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scarches/models/trvae/trvae.py", line 9, in <module>
    from .losses import mse, mmd, zinb, nb
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scarches/models/trvae/losses.py", line 1, in <module>
    from scvi.distributions import NegativeBinomial
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/__init__.py", line 10, in <module>
    from . import autotune, data, model, external, utils
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/autotune/__init__.py", line 1, in <module>
    from ._manager import TuneAnalysis, TunerManager
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/autotune/_manager.py", line 22, in <module>
    from scvi.model.base import BaseModelClass
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/model/__init__.py", line 2, in <module>
    from ._amortizedlda import AmortizedLDA
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/model/_amortizedlda.py", line 14, in <module>
    from scvi.module import AmortizedLDAPyroModule
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/module/__init__.py", line 1, in <module>
    from ._amortizedlda import AmortizedLDAPyroModule
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/module/_amortizedlda.py", line 14, in <module>
    from scvi.module.base import PyroBaseModuleClass, auto_move_data
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/module/base/__init__.py", line 1, in <module>
    from ._base_module import (
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/scvi/module/base/_base_module.py", line 8, in <module>
    import flax
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/flax/__init__.py", line 20, in <module>
    from . import linen
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/flax/linen/__init__.py", line 47, in <module>
    from .attention import (
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/flax/linen/attention.py", line 22, in <module>
    from flax.linen.linear import default_kernel_init
  File "/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/flax/linen/linear.py", line 30, in <module>
    from jax import ShapedArray
ImportError: cannot import name 'ShapedArray' from 'jax' (/users/rng/mambaforge/envs/scai/lib/python3.9/site-packages/jax/__init__.py)
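The last frame shows flax's `linear.py` running `from jax import ShapedArray` and the installed jax not providing that name, so the installed flax and jax releases appear to be mismatched. A minimal sketch that reproduces just that check without importing scarches; `exposes` is a hypothetical helper, not part of scarches, scvi-tools or jax:

```python
import importlib


def exposes(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and defines `attr`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False
    # hasattr returns False if the attribute is missing (AttributeError)
    return hasattr(module, attr)


if __name__ == "__main__":
    # flax/linen/linear.py in the traceback runs `from jax import ShapedArray`
    print("jax exposes ShapedArray:", exposes("jax", "ShapedArray"))
```

If this prints `False`, pinning flax and jax to releases that agree on this import is the thing to look for, independently of the PyTorch/CUDA question.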

YAML file used

name: scai
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.9
  - ipykernel
  - pip
  - numpy
  - pytorch=2.0.1
  - torchaudio=2.0.2
  - torchvision=0.15.2
  - pytorch-cuda=11.7
  - scvi-tools
  - pip:
    - scarches
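To confirm whether an environment built from this file actually got a CUDA-enabled PyTorch, a small check like the following may help. It is a sketch using only standard `torch` attributes (`torch.version.cuda` is `None` on CPU-only builds); `torch_cuda_report` is a hypothetical helper name:

```python
import importlib.util


def torch_cuda_report() -> dict:
    """Summarise whether the installed PyTorch build can use CUDA."""
    if importlib.util.find_spec("torch") is None:
        return {"torch_installed": False}
    import torch

    return {
        "torch_installed": True,
        "torch_version": torch.__version__,
        # None for CPU-only builds, e.g. "11.7" for a CUDA 11.7 build
        "built_with_cuda": torch.version.cuda,
        # False for CPU-only builds, when no GPU is visible,
        # or when the driver cannot run this build's CUDA version
        "cuda_available": torch.cuda.is_available(),
        "device_count": torch.cuda.device_count(),
    }


if __name__ == "__main__":
    for key, value in torch_cuda_report().items():
        print(f"{key}: {value}")
```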

Which combination of exact versions of the following packages works with scarches?

python
pytorch
torchvision
torchaudio
pytorch-cuda
jax
jaxlib
chex
flax
scvi-tools
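One way to collect these versions from a working environment is via `importlib.metadata`. This is a sketch, assuming the pip distribution names below (`torch` for pytorch, `scvi-tools` for scvi-tools); `pytorch-cuda` is, as far as I know, a conda-only metapackage, so its version has to be read from `conda list` instead. `installed_versions` is a hypothetical helper:

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distribution names as pip sees them; "pytorch" is distributed as "torch".
PACKAGES = [
    "torch", "torchvision", "torchaudio",
    "jax", "jaxlib", "chex", "flax",
    "scvi-tools", "scarches",
]


def installed_versions(names=PACKAGES) -> dict:
    """Map each distribution name to its installed version (None if absent)."""
    report = {"python": platform.python_version()}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver}")
```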

Thanks for your help!

LamineTourelab commented 8 months ago

Hi @racng, I had the same problem when installing scarches from the YAML file. For me, the problem was that the pytorch-cuda version did not match the CUDA version my driver supports. Run nvidia-smi to see which CUDA version you need; for me it was 12.1.

name: scarches
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - python=3.9
  - pip
  - numpy
  - pytorch
  - torchaudio
  - torchvision
  - pytorch-cuda=12.1
  - pip:
    - scvi-tools
    - scarches
variables:
  SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL: True

It may help.
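The `nvidia-smi` check described in this comment can be scripted. A minimal sketch, assuming the `nvidia-smi` header contains `CUDA Version: X.Y` and that this value is the highest CUDA version the driver supports, so the `pytorch-cuda` pin should not exceed it; `driver_cuda_version` and `pytorch_cuda_fits` are hypothetical helpers:

```python
import re
import shutil
import subprocess
from typing import Optional, Tuple

# Matches e.g. "CUDA Version: 12.1" in the nvidia-smi header
_CUDA_RE = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def driver_cuda_version(smi_text: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from nvidia-smi output, or None if absent."""
    match = _CUDA_RE.search(smi_text)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def pytorch_cuda_fits(pytorch_cuda: str, driver: Tuple[int, int]) -> bool:
    """Assumed rule: the pinned pytorch-cuda must not exceed the driver's version."""
    major, minor = (int(part) for part in pytorch_cuda.split("."))
    return (major, minor) <= driver


if __name__ == "__main__":
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found on PATH")
    else:
        out = subprocess.run(["nvidia-smi"], capture_output=True, text=True).stdout
        driver = driver_cuda_version(out)
        print("driver CUDA version:", driver)
        if driver is not None:
            for pin in ("11.7", "11.8", "12.1"):
                print(f"pytorch-cuda={pin} fits:", pytorch_cuda_fits(pin, driver))
```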

Koncopd commented 7 months ago

The problem was pytorch-cuda=11.7; I updated it to pytorch-cuda=11.8. It should work now; feel free to reopen if it doesn't.
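Applying the same change to a local copy of an environment file amounts to rewriting one pin. A minimal sketch, assuming the pin is written as `- pytorch-cuda=<version>` on its own line, as in the YAML files in this thread; `set_pytorch_cuda` is a hypothetical helper, not part of any tool:

```python
import re

# Captures the "  - pytorch-cuda=" prefix so only the version is replaced
_PIN_RE = re.compile(r"^([ \t]*-[ \t]*pytorch-cuda[ \t]*=[ \t]*)[0-9.]+", re.MULTILINE)


def set_pytorch_cuda(env_text: str, new_version: str) -> str:
    """Return env_text with the pytorch-cuda pin replaced by new_version."""
    updated, count = _PIN_RE.subn(lambda m: m.group(1) + new_version, env_text)
    if count == 0:
        raise ValueError("no '- pytorch-cuda=<version>' line found")
    return updated


if __name__ == "__main__":
    env = "dependencies:\n  - python=3.9\n  - pytorch-cuda=11.7\n  - pip\n"
    print(set_pytorch_cuda(env, "11.8"))
```

Reading and writing the actual file (for example with `pathlib.Path.read_text` and `write_text`) is left out to keep the sketch side-effect free.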