scverse / pertpy

Perturbation Analysis in the scverse ecosystem.
https://pertpy.readthedocs.io/en/latest/
MIT License

TypeError: unhashable type: 'numpy.ndarray' when running perturbation_space #669

Status: Open. ernesto-iacucci opened this issue 6 days ago.

ernesto-iacucci commented 6 days ago

Report

Hi, when I run the code in "perturbation_space.ipynb" and get to the section on the logistic regression classifier space:

ps = pt.tl.LRClassifierSpace()
psadata = ps.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
psadata

I get the following error:


TypeError                                 Traceback (most recent call last)
Cell In[24], line 2
      1 ps = pt.tl.LRClassifierSpace()
----> 2 psadata = ps.compute(adata, embedding_key="X_pca", target_col="perturbation_name")
      3 psadata

File ~/anaconda3/lib/python3.10/site-packages/pertpy/tools/_perturbation_space/_discriminator_classifiers.py:83, in LRClassifierSpace.compute(self, adata, target_col, layer_key, embedding_key, test_split_size, max_iter)
     81 # Save adata observations for embedding annotations in get_embeddings
     82 adata_obs = adata.obs.reset_index(drop=True)
---> 83 adata_obs = adata_obs.groupby(target_col).agg(
     84     lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
     85 )
     87 # Fit a logistic regression model for each perturbation
     88 regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:1482, in DataFrameGroupBy.aggregate(self, func, engine, engine_kwargs, *args, **kwargs)
   1480 gba = GroupByApply(self, [func], args=(), kwargs={})
   1481 try:
-> 1482     result = gba.agg()
   1484 except ValueError as err:
   1485     if "No objects to concatenate" not in str(err):

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/apply.py:193, in Apply.agg(self)
    190     return self.agg_dict_like()
    191 elif is_list_like(func):
    192     # we require a list, but not a 'str'
--> 193     return self.agg_list_like()
    195 if callable(func):
    196     f = com.get_cython_func(func)

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/apply.py:326, in Apply.agg_list_like(self)
    318 def agg_list_like(self) -> DataFrame | Series:
    319     """
    320     Compute aggregation in the case of a list-like argument.
   (...)
    324         Result of aggregation.
    325     """
--> 326     return self.agg_or_apply_list_like(op_name="agg")

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/apply.py:1571, in GroupByApply.agg_or_apply_list_like(self, op_name)
   1566 # Only set as_index=True on groupby objects, not Window or Resample
   1567 # that inherit from this class.
   1568 with com.temp_setattr(
   1569     obj, "as_index", True, condition=hasattr(obj, "as_index")
   1570 ):
-> 1571     keys, results = self.compute_list_like(op_name, selected_obj, kwargs)
   1572     result = self.wrap_results_list_like(keys, results)
   1573     return result

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/apply.py:385, in Apply.compute_list_like(self, op_name, selected_obj, kwargs)
    379 colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
    380 args = (
    381     [self.axis, *self.args]
    382     if include_axis(op_name, colg)
    383     else self.args
    384 )
--> 385 new_res = getattr(colg, op_name)(func, *args, **kwargs)
    386 results.append(new_res)
    387 indices.append(index)

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:257, in SeriesGroupBy.aggregate(self, func, engine, engine_kwargs, *args, **kwargs)
    255     kwargs["engine"] = engine
    256     kwargs["engine_kwargs"] = engine_kwargs
--> 257 ret = self._aggregate_multiple_funcs(func, *args, **kwargs)
    258 if relabeling:
    259     # columns is not narrowed by mypy from relabeling flag
    260     assert columns is not None  # for mypy

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:362, in SeriesGroupBy._aggregate_multiple_funcs(self, arg, *args, **kwargs)
    360 for idx, (name, func) in enumerate(arg):
    361     key = base.OutputKey(label=name, position=idx)
--> 362     results[key] = self.aggregate(func, *args, **kwargs)
    364 if any(isinstance(x, DataFrame) for x in results.values()):
    365     from pandas import concat

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:294, in SeriesGroupBy.aggregate(self, func, engine, engine_kwargs, *args, **kwargs)
    291     return self._python_agg_general(func, *args, **kwargs)
    293 try:
--> 294     return self._python_agg_general(func, *args, **kwargs)
    295 except KeyError:
    296     # KeyError raised in test_groupby.test_basic is bc the func does
    297     # a dictionary lookup on group.name, but group name is not
    298     # pinned in _python_agg_general, only in _aggregate_named
    299     result = self._aggregate_named(func, *args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:327, in SeriesGroupBy._python_agg_general(self, func, *args, **kwargs)
    324 f = lambda x: func(x, *args, **kwargs)
    326 obj = self._obj_with_exclusions
--> 327 result = self._grouper.agg_series(obj, f)
    328 res = obj._constructor(result, name=obj.name)
    329 return self._wrap_aggregated_output(res)

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/ops.py:864, in BaseGrouper.agg_series(self, obj, func, preserve_dtype)
    857 if not isinstance(obj._values, np.ndarray):
    858     # we can preserve a little bit more aggressively with EA dtype
    859     # because maybe_cast_pointwise_result will do a try/except
    860     # with _from_sequence. NB we are assuming here that _from_sequence
    861     # is sufficiently strict that it casts appropriately.
    862     preserve_dtype = True
--> 864 result = self._aggregate_series_pure_python(obj, func)
    866 npvalues = lib.maybe_convert_objects(result, try_float=False)
    867 if preserve_dtype:

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/ops.py:885, in BaseGrouper._aggregate_series_pure_python(self, obj, func)
    882 splitter = self._get_splitter(obj, axis=0)
    884 for i, group in enumerate(splitter):
--> 885     res = func(group)
    886     res = extract_result(res)
    888     if not initialized:
    889         # We only do this validation on the first iteration

File ~/anaconda3/lib/python3.10/site-packages/pandas/core/groupby/generic.py:324, in SeriesGroupBy._python_agg_general.<locals>.<lambda>(x)
    322     alias = com._builtin_table_alias[func]
    323     warn_alias_replacement(self, orig_func, alias)
--> 324 f = lambda x: func(x, *args, **kwargs)
    326 obj = self._obj_with_exclusions
    327 result = self._grouper.agg_series(obj, f)

File ~/anaconda3/lib/python3.10/site-packages/pertpy/tools/_perturbation_space/_discriminator_classifiers.py:84, in LRClassifierSpace.compute.<locals>.<lambda>(pert_group)
     81 # Save adata observations for embedding annotations in get_embeddings
     82 adata_obs = adata.obs.reset_index(drop=True)
     83 adata_obs = adata_obs.groupby(target_col).agg(
---> 84     lambda pert_group: np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]
     85 )
     87 # Fit a logistic regression model for each perturbation
     88 regression_model = LogisticRegression(max_iter=max_iter, class_weight="balanced")

TypeError: unhashable type: 'numpy.ndarray'
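The last frame shows where the error originates: the lambda on line 84 hands every adata.obs column, one perturbation group at a time, to set(), and set() raises this TypeError as soon as it meets a cell holding a numpy array. The sketch below reproduces that pattern with plain pandas (2.2, as in the reported environment) on a hypothetical toy table. The column name guide_vector and the helpers _as_hashable and collapse_tolerant are illustrative assumptions, not part of pertpy; collapse_tolerant is one possible way to make the aggregation tolerate array-valued cells, not the fix adopted upstream.

```python
import numpy as np
import pandas as pd

# Toy stand-in for adata.obs (hypothetical data). "guide_vector" is an assumed
# column whose cells hold numpy arrays, which set() cannot hash.
obs = pd.DataFrame(
    {
        "perturbation_name": ["A", "A", "B"],
        "batch": ["b1", "b1", "b2"],
        "guide_vector": [np.array([1, 0]), np.array([1, 0]), np.array([0, 1])],
    }
)


def collapse(pert_group):
    # Same expression as the lambda on line 84 of _discriminator_classifiers.py.
    return np.nan if len(set(pert_group)) != 1 else list(set(pert_group))[0]


def _as_hashable(value):
    # Hypothetical helper: turn arrays (flattened) and lists into tuples.
    if isinstance(value, np.ndarray):
        return tuple(value.ravel().tolist())
    if isinstance(value, list):
        return tuple(value)
    return value


def collapse_tolerant(pert_group):
    # Variant of collapse() that puts converted values into the set.
    unique = {_as_hashable(v) for v in pert_group}
    return np.nan if len(unique) != 1 else next(iter(unique))


try:
    obs.groupby("perturbation_name").agg(collapse)
    message = None
except TypeError as err:
    message = str(err)

print(message)  # unhashable type: 'numpy.ndarray'

tolerant = obs.groupby("perturbation_name").agg(collapse_tolerant)
print(tolerant)
```

Converting arrays to tuples keeps the "exactly one unique value per group" check meaningful, at the cost of returning tuples instead of arrays in the collapsed table.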

Version information


anndata 0.10.8, pandas 2.2.3, pertpy 0.9.4, scanpy 1.10.3, session_info 1.0.0

Cython 3.0.10, OpenSSL 24.2.1, PIL 10.4.0, PyQt5 NA, absl NA, accelerate 0.34.2, adjustText 1.2.0, anyio NA, arrow 1.3.0, arviz 0.20.0, asttokens NA, attr 23.2.0, attrs 23.2.0, babel 2.14.0, blitzgsea NA, bokeh 3.5.1, boto3 1.35.16, botocore 1.35.16, bottleneck 1.4.0, brotli NA, certifi 2024.07.04, cffi 1.16.0, chardet 5.2.0, charset_normalizer 3.2.0, chex 0.1.87, cloudpickle 2.2.1, colorama 0.4.6, comm 0.2.2, contextlib2 NA, cryptography 42.0.8, custom_inherit 2.4.1, cycler 0.12.1, cython 3.0.10, cython_runtime NA, cytoolz 0.12.3, dask 2024.7.1, dateutil 2.9.0, debugpy 1.8.2, decorator 5.1.1, decoupler 1.8.0, defusedxml 0.7.1, dill 0.3.8, docrep 0.3.2, equinox 0.11.7, ete3 3.1.3, etils 1.9.4, exceptiongroup 1.2.2, executing 2.0.1, fastjsonschema NA, filelock 3.15.4, flax 0.9.0, formulaic 1.0.2, fqdn NA, fsspec 2024.6.1, gmpy2 2.1.5, google NA, graphlib NA, h5py 3.11.0, huggingface_hub 0.25.1, idna 3.7, igraph 0.11.6, importlib_metadata NA, interface_meta 1.3.0, ipykernel 6.29.5, ipywidgets 8.1.3, isoduration NA, jax 0.4.33, jaxlib 0.4.33, jaxopt NA, jaxtyping 0.2.34, jedi 0.18.2, jinja2 3.1.4, jmespath 1.0.1, joblib 1.4.2, json5 0.9.25, jsonpointer 2.0, jsonschema 4.23.0, jsonschema_specifications NA, jupyter_events 0.10.0, jupyter_server 2.14.2, jupyterlab_server 2.27.3, kiwisolver 1.4.5, lamin_utils 0.13.6, legacy_api_wrap NA, leidenalg 0.10.2, lightning_fabric 1.9.5, lightning_utilities 0.11.7, lineax 0.0.6, llvmlite 0.43.0, loguru 0.7.2, lxml 5.2.2, lz4 4.3.3, markupsafe 2.1.5, matplotlib 3.8.4, matplotlib_inline 0.1.7, ml_collections NA, ml_dtypes 0.5.0, mpl_toolkits NA, mpmath 1.3.0, msgpack 1.0.8, mudata 0.3.1, multipledispatch 0.6.0, natsort 8.4.0, nbformat 5.10.4, numba 0.60.0, numexpr 2.10.0, numpy 1.26.3, numpyro 0.15.3, opt_einsum 3.4.0, optax 0.2.3, ott 0.4.8, overrides NA, packaging 24.1, parso 0.8.4, patsy 0.5.6, pickleshare 0.7.5, pkg_resources NA, platformdirs 4.2.2, plotly 5.23.0, ply 3.11, png 0.20220715.0, prometheus_client NA, prompt_toolkit 3.0.47, psutil 6.0.0, pubchempy 1.0.4, pure_eval 0.2.3, pyarrow 17.0.0, pycparser 2.22, pydeseq2 0.4.11, pydev_ipython NA, pydevconsole NA, pydevd 2.9.5, pydevd_file_utils NA, pydevd_plugins NA, pydevd_tracing NA, pydot 3.0.2, pygments 2.18.0, pynndescent 0.5.13, pynvml 11.5.3, pyomo 6.8.0, pyparsing 3.1.2, pyro 1.9.1, pythonjsonlogger NA, pytorch_lightning 1.9.5, pytz 2024.1, ray 2.37.0, referencing NA, regex 2.5.146, reportlab 4.2.5, requests 2.32.3, rfc3339_validator 0.1.4, rfc3986_validator 0.1.1, rich NA, rpds NA, safetensors 0.4.5, scipy 1.14.0, scvi 0.20.3, seaborn 0.13.2, send2trash NA, setproctitle 1.2.2, setuptools 70.0.0, sip NA, six 1.16.0, sklearn 1.5.1, skmisc 0.5.1, sniffio 1.3.1, socks 1.7.1, sparsecca 0.3.1, sphinxcontrib NA, stack_data 0.6.2, statsmodels 0.14.2, sympy 1.13.0, tblib 3.0.0, tensorboard 2.18.0, texttable 1.7.0, threadpoolctl 3.5.0, tlz 0.12.3, tokenizers 0.20.0, toolz 0.12.1, torch 2.4.1+cu121, torchgen NA, torchmetrics 1.4.2, tornado 6.4.1, toyplot 2.0.0, toytree 3.0.5, tqdm 4.66.4, traitlets 5.14.3, transformers 4.45.1, triton 3.0.0, typeguard NA, typing_extensions NA, ujson 5.10.0, umap 0.5.6, uri_template NA, urllib3 2.2.2, wcwidth 0.2.13, webcolors 24.6.0, websocket 1.8.0, wrapt 1.16.0, xarray 2024.9.0, xarray_einstats 0.8.0, xxhash NA, xyzservices 2024.6.0, yaml 6.0.1, zipp NA, zmq 26.0.3, zoneinfo NA, zope NA, zstandard 0.19.0

IPython 8.26.0, jupyter_client 8.6.2, jupyter_core 5.7.2, jupyterlab 4.2.4, notebook 7.2.1

Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:40:32) [GCC 12.3.0] Linux-5.10.226-214.879.amzn2.x86_64-x86_64-with-glibc2.26

Session information updated at 2024-10-14 22:51

Lilly-May commented 4 days ago

Hi @ernesto-iacucci! Thanks for reporting this issue. I just tried, but I couldn’t reproduce your error. Are you using the dataset from the tutorial (pt.dt.norman_2019()) or your own data?

ernesto-iacucci commented 3 days ago

Hi, I am using: adata = pt.dt.norman_2019()
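Because the same tutorial dataset works for the maintainer, one way to narrow this down is to check which adata.obs columns, in the environment where the error occurs, contain values that set() cannot hash. A minimal sketch with plain pandas; unhashable_columns is a hypothetical helper (not a pertpy or pandas function), and the toy DataFrame only stands in for adata.obs.

```python
import numpy as np
import pandas as pd


def unhashable_columns(obs: pd.DataFrame) -> list[str]:
    """Return the names of columns holding at least one unhashable value."""
    bad = []
    for col in obs.columns:
        for value in obs[col]:
            try:
                hash(value)
            except TypeError:
                bad.append(col)
                break  # one unhashable cell is enough to flag the column
    return bad


# Toy example (hypothetical data); in the failing session you would pass adata.obs.
obs = pd.DataFrame(
    {
        "perturbation_name": ["A", "B"],
        "nperts": [1, 2],
        "guide_vector": [np.array([1, 0]), np.array([0, 1])],
    }
)
print(unhashable_columns(obs))  # ['guide_vector']
```

In the failing session this would be called as unhashable_columns(adata.obs) right after adata = pt.dt.norman_2019(); any column it lists is a candidate for triggering the error inside LRClassifierSpace.compute.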