immunogenomics / starCAT

Implements *CellAnnotator (aka *CAT/starCAT), annotating scRNA-Seq with predefined gene expression programs
MIT License

Issue with dtype of X #1

Closed: bharris12 closed this issue 4 months ago.

bharris12 commented 4 months ago

Hello, I am trying to run your tool and am running into an error when I run my own data with it.


I've tried casting adata.X to both float32 and int and still get the same error.

dylkot commented 4 months ago

Hey @bharris12, thanks for posting this! Can you provide the full set of commands you are running and the full error trace? Can you also confirm whether the example vignette works for you?

bharris12 commented 4 months ago

Thank you for your quick response.

The example vignette fails as well. Here is the full error trace:

TypeError                                 Traceback (most recent call last)
Cell In[10], line 2
      1 # Run starCAT to compute the usages and scores for the provided data 
----> 2 usage, scores = tcat.fit_transform(adata)

File ~/.venv/lib/python3.10/site-packages/starcat/starcat.py:290, in starCAT.fit_transform(self, query, return_unnormalized)
    275 """
    276 Takes an input data matrix and a fixed spectra and uses NNLS to find the optimal
    277 usage matrix. If input data are pandas.DataFrame, returns a DataFrame with row
   (...)
    285 
    286 """
    288 query = self.prepare_query(query)
--> 290 self.usage = self.fit_query_usage(query)
    291 self.usage_norm = self.usage.div(self.usage.sum(axis=1), axis=0)
    293 if len(self.score_data) > 0:

File ~/.venv/lib/python3.10/site-packages/starcat/starcat.py:333, in starCAT.fit_query_usage(self, query)
    332 def fit_query_usage(self, query):
--> 333     rf_usages = self.refit_usage(query.X, self.ref[self.overlap_genes].values,
    334                      self._nmf_kwargs.copy())          
    335     rf_usages = pd.DataFrame(rf_usages, index=query.obs.index,
    336                              columns=self.ref.index)
    337     return(rf_usages)

File ~/.venv/lib/python3.10/site-packages/starcat/starcat.py:357, in starCAT.refit_usage(self, X, spectra, nmf_kwargs)
    355 nmf_kwargs['H'] = spectra
    356 nmf_kwargs['n_components'] = spectra.shape[0]
--> 357 _, rf_usages = self._nmf(X, nmf_kwargs=nmf_kwargs)            
    358 return(rf_usages)

File ~/.venv/lib/python3.10/site-packages/cnmf/cnmf.py:564, in cNMF._nmf(self, X, nmf_kwargs)
    553 def _nmf(self, X, nmf_kwargs):
    554     """
    555     Parameters
    556     ----------
   (...)
    562 
    563     """
--> 564     (usages, spectra, niter) = non_negative_factorization(X, **nmf_kwargs)
    566     return(spectra, usages)

File ~/.venv/lib/python3.10/site-packages/sklearn/utils/_param_validation.py:213, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    207 try:
    208     with config_context(
    209         skip_parameter_validation=(
    210             prefer_skip_nested_validation or global_skip_validation
    211         )
    212     ):
--> 213         return func(*args, **kwargs)
    214 except InvalidParameterError as e:
    215     # When the function is just a wrapper around an estimator, we allow
    216     # the function to delegate validation to the estimator, but we replace
    217     # the name of the estimator by the name of the function in the error
    218     # message to avoid confusion.
    219     msg = re.sub(
    220         r"parameter of \w+ must be",
    221         f"parameter of {func.__qualname__} must be",
    222         str(e),
    223     )

File ~/.venv/lib/python3.10/site-packages/sklearn/decomposition/_nmf.py:1133, in non_negative_factorization(X, W, H, n_components, init, update_H, solver, beta_loss, tol, max_iter, alpha_W, alpha_H, l1_ratio, random_state, verbose, shuffle)
   1130 X = check_array(X, accept_sparse=("csr", "csc"), dtype=[np.float64, np.float32])
   1132 with config_context(assume_finite=True):
-> 1133     W, H, n_iter = est._fit_transform(X, W=W, H=H, update_H=update_H)
   1135 return W, H, n_iter

File ~/.venv/lib/python3.10/site-packages/sklearn/decomposition/_nmf.py:1730, in NMF._fit_transform(self, X, y, W, H, update_H)
   1723     raise ValueError(
   1724         "When beta_loss <= 0 and X contains zeros, "
   1725         "the solver may diverge. Please add small values "
   1726         "to X, or use a positive beta_loss."
   1727     )
   1729 # initialize or check W and H
-> 1730 W, H = self._check_w_h(X, W, H, update_H)
   1732 # scale the regularization terms
   1733 l1_reg_W, l1_reg_H, l2_reg_W, l2_reg_H = self._compute_regularization(X)

File ~/.venv/lib/python3.10/site-packages/sklearn/decomposition/_nmf.py:1242, in _BaseNMF._check_w_h(self, X, W, H, update_H)
   1239     self._n_components = H.shape[0]
   1241 if H.dtype != X.dtype:
-> 1242     raise TypeError(
   1243         "H should have the same dtype as X. Got H.dtype = {}.".format(
   1244             H.dtype
   1245         )
   1246     )
   1248 # 'mu' solver should not be initialized by zeros
   1249 if self.solver == "mu":

TypeError: H should have the same dtype as X. Got H.dtype = float32.

dylkot commented 4 months ago

OK, thanks for posting this. It seems that in later versions of scanpy (I confirmed this for scanpy 1.10.1) with Python 3.10, the sc.pp.scale preprocessing step converts the input data to float64, while the reference spectra are float32. That causes the issue: scikit-learn's NMF requires the fixed spectra (H) to have the same dtype as the query matrix (X). This didn't come up when we tested on Python 3.7 and 3.8, so we missed it. Either way, we should be making sure the reference and query have the same dtype, so I fixed this on GitHub and in version 1.0.3 on PyPI.

https://github.com/immunogenomics/starCAT/blob/main/src/starcat/starcat.py#L328-L330
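To reproduce the mismatch outside starCAT, here is a minimal sketch that calls scikit-learn's `non_negative_factorization` directly with the spectra held fixed (`update_H=False`), assuming (as the traceback suggests) that this is how starCAT refits usages. The random matrices and the helper `refit_usage_matched` are hypothetical stand-ins, not starCAT's code, and casting `H` to `X.dtype` illustrates the kind of dtype alignment described above, not necessarily the exact change in the linked lines.

```python
import numpy as np
from sklearn.decomposition import non_negative_factorization


def refit_usage_matched(X, spectra, max_iter=500):
    """Hypothetical helper: fit usages W for fixed spectra H (update_H=False),
    casting H to X's dtype so scikit-learn's dtype check passes."""
    H = np.asarray(spectra).astype(X.dtype, copy=False)
    W, _, _ = non_negative_factorization(
        X, H=H, n_components=H.shape[0], init="custom",
        update_H=False, solver="mu", beta_loss="frobenius", max_iter=max_iter,
    )
    return W


rng = np.random.default_rng(0)
X = rng.random((20, 50))                          # float64, like the scaled query data
spectra = rng.random((5, 50)).astype(np.float32)  # float32, like the reference spectra

# Without the cast, scikit-learn raises the same TypeError as in the traceback.
try:
    non_negative_factorization(X, H=spectra, n_components=5, init="custom",
                               update_H=False, solver="mu")
    mismatch_error = None
except TypeError as err:
    mismatch_error = str(err)

usage = refit_usage_matched(X, spectra)
print(mismatch_error)
print(usage.shape, usage.dtype)
```

Casting the spectra rather than the query keeps `X` in whatever precision preprocessing produced; casting `X` down to float32 would also satisfy the check, at some cost in precision.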

Would you mind reinstalling it with:

pip uninstall starcatpy
pip install starcatpy

and trying one more time? Thanks again for raising this!
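After reinstalling, one way to confirm the environment actually picked up a release with the fix (1.0.3 or later, per the comment above) is to read the installed distribution version. This is a minimal sketch using only the standard library; the helper names are hypothetical, and it assumes the PyPI distribution name `starcatpy` from the pip commands.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def version_triple(text):
    """Hypothetical helper: parse the leading 'major.minor[.patch]' of a version.
    Pre-release/post-release suffixes are ignored."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    major, minor, patch = match.groups(default="0")
    return int(major), int(minor), int(patch)


def has_dtype_fix(dist_name="starcatpy", minimum=(1, 0, 3)):
    """Hypothetical helper: True/False if the package is installed, None if not."""
    try:
        installed = version(dist_name)
    except PackageNotFoundError:
        return None
    return version_triple(installed) >= minimum


print(has_dtype_fix())
```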

bharris12 commented 4 months ago

That worked.

Thank you for all your help and sharing this tool!

dylkot commented 4 months ago

Awesome. I hope it's useful!