scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License
1.23k stars 349 forks source link

Error Importing - dadapt_adamw; 'type' object is not subscriptable #2424

Closed rigelsison closed 9 months ago

rigelsison commented 9 months ago

I was trying to import `scvi` and came across the error below. I'm not sure how to address it, and any help would be appreciated.

import os
import numpy as np
import pandas as pd
import scipy
import anndata
import scanpy as sc
import pybiomart
import scvi
import torch
import random
import seaborn as sns
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[24], line 8
      6 import scanpy as sc
      7 import pybiomart
----> 8 import scvi
      9 import torch
     10 import random

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/__init__.py:10
      7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import autotune, data, model, external, utils
     12 from importlib.metadata import version
     14 package_name = "scvi-tools"

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/autotune/_manager.py:22
     20 from scvi._types import AnnOrMuData
     21 from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME
---> 22 from scvi.model.base import BaseModelClass
     23 from scvi.utils import InvalidParameterError
     25 from ._defaults import COLORS, COLUMN_KWARGS, DEFAULTS, TUNABLE_TYPES

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/model/_amortizedlda.py:14
     12 from scvi.data import AnnDataManager
     13 from scvi.data.fields import LayerField
---> 14 from scvi.module import AmortizedLDAPyroModule
     15 from scvi.utils import setup_anndata_dsp
     17 from .base import BaseModelClass, PyroSviTrainMixin

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/module/__init__.py:1
----> 1 from ._amortizedlda import AmortizedLDAPyroModule
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/module/_amortizedlda.py:14
     12 from scvi._constants import REGISTRY_KEYS
     13 from scvi.autotune._types import Tunable
---> 14 from scvi.module.base import PyroBaseModuleClass, auto_move_data
     15 from scvi.nn import Encoder
     17 _AMORTIZED_LDA_PYRO_MODULE_NAME = "amortized_lda"

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/module/base/__init__.py:1
----> 1 from ._base_module import (
      2     BaseMinifiedModeModuleClass,
      3     BaseModuleClass,
      4     JaxBaseModuleClass,
      5     LossOutput,
      6     PyroBaseModuleClass,
      7     TrainStateWithState,
      8 )
      9 from ._decorators import auto_move_data, flax_configure
     11 __all__ = [
     12     "BaseModuleClass",
     13     "LossOutput",
   (...)
     19     "BaseMinifiedModeModuleClass",
     20 ]

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/scvi/module/base/_base_module.py:15
     13 import torch
     14 from flax.core import FrozenDict
---> 15 from flax.training import train_state
     16 from jax import random
     17 from jaxlib.xla_extension import Device

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/flax/training/train_state.py:19
     17 from flax import core
     18 from flax import struct
---> 19 import optax
     22 class TrainState(struct.PyTreeNode):
     23   """Simple train state for the common case with a single Optax optimizer.
     24 
     25   Synopsis::
   (...)
     50     opt_state: The state for `tx`.
     51   """

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/optax/__init__.py:17
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import contrib
     18 from optax import losses
     19 from optax import monte_carlo

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/optax/contrib/__init__.py:21
     19 from optax.contrib.complex_valued import split_real_and_imaginary
     20 from optax.contrib.complex_valued import SplitRealAndImaginaryState
---> 21 from optax.contrib.dadapt_adamw import dadapt_adamw
     22 from optax.contrib.dadapt_adamw import DAdaptAdamWState
     23 from optax.contrib.mechanic import MechanicState

File ~/Library/jupyterlab-desktop/jlab_server/lib/python3.8/site-packages/optax/contrib/dadapt_adamw.py:45
     39   numerator_weighted: chex.Array  # shape=(), dtype=jnp.float32.
     40   count: chex.Array  # shape=(), dtype=jnp.int32.
     43 def dadapt_adamw(
     44     learning_rate: base.ScalarOrSchedule = 1.0,
---> 45     betas: tuple[float, float] = (0.9, 0.999),
     46     eps: float = 1e-8,
     47     estim_lr0: float = 1e-6,
     48     weight_decay: float = 0.,
     49 ) -> base.GradientTransformation:
     50   """Learning rate free AdamW by D-Adaptation.
     51 
     52   Adapts the baseline learning rate of AdamW automatically by estimating the
   (...)
     69     A `GradientTransformation` object.
     70   """
     72   def init_fn(params: base.Params) -> DAdaptAdamWState:

TypeError: 'type' object is not subscriptable

martinkim0 commented 9 months ago

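Hi, it looks like you are trying to use scvi-tools with Python 3.8. Our newer releases no longer support that version, and this is likely the source of the issue: the traceback fails on the annotation `tuple[float, float]` in optax, and subscripting built-in types like `tuple` only works on Python 3.9 and later. I would suggest updating to Python 3.9 - 3.11.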

rigelsison commented 9 months ago

That fixed the issue, thank you!