chanzuckerberg / cellxgene-census

CZ CELLxGENE Discover Census
https://chanzuckerberg.github.io/cellxgene-census/
MIT License
79 stars 20 forks source link

highly_variable_genes doesn't work with coords, only with value_filter #937

Closed canergen closed 8 months ago

canergen commented 8 months ago

The highly_variable_genes function doesn't work when the query is built with coords (it only works with value_filter). The same error appears when calling query.n_vars.

query = census["census_data"]["homo_sapiens"].axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter="is_primary_data == True"),
    var_query=soma.AxisQuery(coords=np.arange(100)), # not sure this line works
)

hvgs_df = highly_variable_genes(query, n_top_genes=6000, batch_key=["suspension_type", "assay"])

hv = hvgs_df.highly_variable

hv.to_pickle("hv_genes.pkl")
hv_idx = hv[hv].index
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[14], line 1
----> 1 hvgs_df = highly_variable_genes(query, n_top_genes=6000, batch_key=["suspension_type", "assay"])

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/cellxgene_census/experimental/pp/_highly_variable_genes.py:320, in highly_variable_genes(query, n_top_genes, layer, flavor, span, batch_key, max_loess_jitter, batch_key_func)
    317 if flavor != "seurat_v3":
    318     raise ValueError('`flavor` must be "seurat_v3"')
--> 320 return _highly_variable_genes_seurat_v3(
    321     query,
    322     n_top_genes=n_top_genes,
    323     layer=layer,
    324     span=span,
    325     batch_key=batch_key,
    326     batch_key_func=batch_key_func,
    327     max_loess_jitter=max_loess_jitter,
    328 )

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/cellxgene_census/experimental/pp/_highly_variable_genes.py:109, in _highly_variable_genes_seurat_v3(query, batch_key, n_top_genes, layer, span, max_loess_jitter, batch_key_func)
    106 var_indexer = query.indexer
    108 with futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
--> 109     mvn = MeanVarianceAccumulator(n_batches, n_samples, query.n_vars)
    110     for arrow_tbl in _EagerIterator(query.X(layer).tables(), pool=pool):
    111         data = arrow_tbl["soma_data"].to_numpy()

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/somacore/query/query.py:183, in ExperimentAxisQuery.n_vars(self)
    177 @property
    178 def n_vars(self) -> int:
    179     """The number of ``var`` axis query results.
    180 
    181     Lifecycle: maturing
    182     """
--> 183     return len(self.var_joinids())

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/somacore/query/query.py:167, in ExperimentAxisQuery.var_joinids(self)
    162 def var_joinids(self) -> pa.Array:
    163     """Returns ``var`` ``soma_joinids`` as an Arrow array.
    164 
    165     Lifecycle: maturing
    166     """
--> 167     return self._joinids.var

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/somacore/query/query.py:707, in _JoinIDCache.var(self)
    705 """Join IDs for the var axis. Will load and cache if not already."""
    706 if not self._cached_var:
--> 707     self._cached_var = _load_joinids(
    708         self.owner._var_df, self.owner._matrix_axis_query.var
    709     )
    710 return self._cached_var

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/somacore/query/query.py:718, in _load_joinids(df, axq)
    717 def _load_joinids(df: data.DataFrame, axq: axis.AxisQuery) -> pa.Array:
--> 718     tbl = df.read(
    719         axq.coords,
    720         value_filter=axq.value_filter,
    721         column_names=["soma_joinid"],
    722     ).concat()
    723     return tbl.column("soma_joinid").combine_chunks()

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/tiledbsoma/_dataframe.py:360, in DataFrame.read(***failed resolving arguments***)
    351     query_condition = QueryCondition(value_filter)
    353 sr = self._soma_reader(
    354     schema=schema,  # query_condition needs this
    355     column_names=column_names,
    356     query_condition=query_condition,
    357     result_order=result_order,
    358 )
--> 360 self._set_reader_coords(sr, coords)
    362 # TODO: platform_config
    363 # TODO: batch_size
    365 return TableReadIter(sr)

File /data/yosef3/users/can/conda/envs/scvi_census/lib/python3.10/site-packages/tiledbsoma/_tiledb_array.py:150, in TileDBArray._set_reader_coords(self, sr, coords)
    148 schema = self._handle.schema
    149 if len(coords) > schema.domain.ndim:
--> 150     raise ValueError(
    151         f"coords ({len(coords)} elements) must be shorter than ndim"
    152         f" ({schema.domain.ndim})"
    153     )
    154 for i, coord in enumerate(coords):
    155     dim = self._handle.schema.domain.dim(i)

ValueError: coords (18974 elements) must be shorter than ndim (1)
bkmartinjr commented 8 months ago

@canergen - I believe there is a typo in your query args.

You want:

var_query=soma.AxisQuery(coords=(np.arange(100),))

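I.e., the coords param takes a tuple with one entry per dimension, so the array of indices must itself be wrapped in a one-element tuple. Passing the bare array makes each of its elements look like a separate dimension's coordinate, which is why the error reports more coords than ndim (1).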

Let me know if this doesn't resolve the issue.

canergen commented 8 months ago

Solved. Thanks