theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License
347 stars 46 forks source link

`gene_trends` fails for sklearn model #284

Closed cdedonno closed 4 years ago

cdedonno commented 4 years ago

I'm trying to plot gene expression using

from sklearn.svm import SVR
model = cr.ul.models.SKLearnModel(adata_oligo_wt, model=SVR)  
cr.pl.gene_trends(adata_oligo_wt, model=model, data_key='X',
                  genes=['Sgk1'], 
                  time_key='dpt_pseudotime')

The error I get is

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-19838d1010b5> in <module>
      3 cr.pl.gene_trends(adata_oligo_wt, model=model, data_key='X',
      4                   genes=['Sgk1'],
----> 5                   time_key='dpt_pseudotime')

~/anaconda3/envs/sc/lib/python3.7/site-packages/cellrank/plotting/_gene_trend.py in gene_trends(adata, model, genes, lineages, data_key, final, start_lineage, end_lineage, conf_int, same_plot, hide_cells, perc, lineage_cmap, abs_prob_cmap, cell_color, color, cell_alpha, lineage_alpha, size, lw, show_cbar, margins, sharey, figsize, dpi, ncols, n_jobs, backend, ext, suptitle, save, dirname, plot_kwargs, show_progres_bar, **kwargs)
    270         extractor=lambda modelss: {k: v for m in modelss for k, v in m.items()},
    271         show_progress_bar=show_progres_bar,
--> 272     )(lineages, start_lineage, end_lineage, **kwargs)
    273     logg.info("    Finish", time=start)
    274 

~/anaconda3/envs/sc/lib/python3.7/site-packages/cellrank/utils/_parallelize.py in wrapper(*args, **kwargs)
    117                 *((i, cs) if use_ixs else (cs,)), *args, **kwargs, queue=queue
    118             )
--> 119             for i, cs in enumerate(collections)
    120         )
    121 

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/parallel.py in __call__(self, iterable)
   1027             # remaining jobs.
   1028             self._iterating = False
-> 1029             if self.dispatch_one_batch(iterator):
   1030                 self._iterating = self._original_iterator is not None
   1031 

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/parallel.py in dispatch_one_batch(self, iterator)
    845                 return False
    846             else:
--> 847                 self._dispatch(tasks)
    848                 return True
    849 

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/parallel.py in _dispatch(self, batch)
    763         with self._lock:
    764             job_idx = len(self._jobs)
--> 765             job = self._backend.apply_async(batch, callback=cb)
    766             # A job can complete so quickly than its callback is
    767             # called before we get here, causing self._jobs to

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/_parallel_backends.py in apply_async(self, func, callback)
    206     def apply_async(self, func, callback=None):
    207         """Schedule a func to be run"""
--> 208         result = ImmediateResult(func)
    209         if callback:
    210             callback(result)

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/_parallel_backends.py in __init__(self, batch)
    570         # Don't delay the application, to avoid keeping the input
    571         # arguments in memory
--> 572         self.results = batch()
    573 
    574     def get(self):

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/parallel.py in __call__(self)
    251         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    252             return [func(*args, **kwargs)
--> 253                     for func, args, kwargs in self.items]
    254 
    255     def __reduce__(self):

~/anaconda3/envs/sc/lib/python3.7/site-packages/joblib/parallel.py in <listcomp>(.0)
    251         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    252             return [func(*args, **kwargs)
--> 253                     for func, args, kwargs in self.items]
    254 
    255     def __reduce__(self):

~/anaconda3/envs/sc/lib/python3.7/site-packages/cellrank/plotting/_utils.py in _fit(genes, lineage_names, start_lineages, end_lineages, queue, **kwargs)
    424             model = (
    425                 models[gene][ln]
--> 426                 .prepare(gene, ln, start_lineage=sc, end_lineage=ec, **kwargs)
    427                 .fit()
    428             )

~/anaconda3/envs/sc/lib/python3.7/site-packages/cellrank/utils/models/_models.py in fit(self, x, y, w, **kwargs)
    733 
    734         fit_fn = getattr(self.model, self._fit_name)
--> 735         self._model = fit_fn(self.x, self.y, **kwargs)
    736 
    737         return self

TypeError: fit() missing 1 required positional argument: 'y'

I made sure my lineage probs have been written to .obsm and the time key also exists in .obs.

Versions:

cellrank==1.0.0-rc.0 scanpy==1.5.1 anndata==0.7.3 numpy==1.18.5 scipy==1.5.0 pandas==1.0.5 scikit-learn==0.23.1 statsmodels==0.11.1 python-igraph==0.8.2 scvelo==0.2.1

...

michalk8 commented 4 years ago

Hi @cdedonno, you must pass an instance of the model, not the class (some sklearn-like models are fitted during initialization, which is why it doesn't complain). Passing an instance, e.g. `model=SVR()` instead of `model=SVR`, should work as intended.