theislab / cellrank

CellRank: dynamics from multi-view single-cell data
https://cellrank.org
BSD 3-Clause "New" or "Revised" License
350 stars 46 forks source link

plot_random_walks csr_matrix error #1213

Closed simozhou closed 3 months ago

simozhou commented 3 months ago

I instantiated a `PseudotimeKernel` with params [n=220000, dnorm=False, scheme='hard', frac_to_keep=0.3] on my own data. When I call `plot_random_walks`, I get the following error:

AttributeError: 'csr_matrix' object has no attribute 'A'

The code to reproduce it is minimal:

pk = cr.kernels.PseudotimeKernel(adata, time_key="MLP_interval_age")
pk.compute_transition_matrix()

pk.plot_random_walks()

and the full error output I get is:

Simulating `100` random walks of maximum length `55000`

  0%|          | 0/100 [00:00<?, ?sim/s]
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[39], line 1
----> 1 g.kernel.plot_random_walks()

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/_utils.py:184, in require_tmat(wrapped, instance, args, kwargs)
    182 if instance.transition_matrix is None:
    183     raise RuntimeError("Compute transition matrix first as `.compute_transition_matrix()`.")
--> 184 return wrapped(*args, **kwargs)

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/_base_kernel.py:236, in KernelExpression.plot_random_walks(self, n_sims, max_iter, seed, successive_hits, start_ixs, stop_ixs, basis, cmap, linewidth, linealpha, ixs_legend_loc, n_jobs, backend, show_progress_bar, figsize, dpi, save, **kwargs)
    193 """Plot random walks in an embedding.
    194 
    195 This method simulates random walks on the Markov chain defined though the corresponding transition matrix. The
   (...)
    233 For each random walk, the first/last cell is marked by the start/end colors of ``cmap``.
    234 """
    235 rw = RandomWalk(self.adata, self.transition_matrix, start_ixs=start_ixs, stop_ixs=stop_ixs)
--> 236 sims = rw.simulate_many(
    237     n_sims=n_sims,
    238     max_iter=max_iter,
    239     seed=seed,
    240     n_jobs=n_jobs,
    241     backend=backend,
    242     successive_hits=successive_hits,
    243     show_progress_bar=show_progress_bar,
    244 )
    246 rw.plot(
    247     sims,
    248     basis=basis,
   (...)
    256     **kwargs,
    257 )

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/utils/_random_walk.py:171, in RandomWalk.simulate_many(self, n_sims, max_iter, seed, successive_hits, n_jobs, backend, show_progress_bar)
    168 max_iter = self._max_iter(max_iter)
    169 start = logg.info(f"Simulating `{n_sims}` random walks of maximum length `{max_iter}`")
--> 171 simss = parallelize(
    172     self._simulate_many,
    173     collection=np.arange(n_sims),
    174     n_jobs=n_jobs,
    175     backend=backend,
    176     show_progress_bar=show_progress_bar,
    177     as_array=False,
    178     unit="sim",
    179 )(max_iter=max_iter, seed=seed, successive_hits=successive_hits)
    180 simss = list(itertools.chain.from_iterable(simss))
    182 logg.info("    Finish", time=start)

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/_utils/_parallelize.py:96, in parallelize.<locals>.wrapper(*args, **kwargs)
     93 else:
     94     pbar, queue, thread = None, None, None
---> 96 res = jl.Parallel(n_jobs=n_jobs, backend=backend)(
     97     jl.delayed(callback)(
     98         *((i, cs) if use_ixs else (cs,)),
     99         *args,
    100         **kwargs,
    101         queue=queue,
    102     )
    103     for i, cs in enumerate(collections)
    104 )
    106 res = np.array(res) if as_array else res
    107 if thread is not None:

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/joblib/parallel.py:1918, in Parallel.__call__(self, iterable)
   1916     output = self._get_sequential_output(iterable)
   1917     next(output)
-> 1918     return output if self.return_generator else list(output)
   1920 # Let's create an ID that uniquely identifies the current call. If the
   1921 # call is interrupted early and that the same instance is immediately
   1922 # re-used, this id will be used to prevent workers that were
   1923 # concurrently finalizing a task from the previous call to run the
   1924 # callback.
   1925 with self._lock:

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/joblib/parallel.py:1847, in Parallel._get_sequential_output(self, iterable)
   1845 self.n_dispatched_batches += 1
   1846 self.n_dispatched_tasks += 1
-> 1847 res = func(*args, **kwargs)
   1848 self.n_completed_tasks += 1
   1849 self.print_progress()

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/utils/_random_walk.py:127, in RandomWalk._simulate_many(self, sims, max_iter, seed, successive_hits, queue)
    125 res = []
    126 for s in sims:
--> 127     sim = self.simulate_one(
    128         max_iter=max_iter,
    129         seed=None if seed is None else seed + s,
    130         successive_hits=successive_hits,
    131     )
    132     res.append(sim)
    133     if queue is not None:

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/utils/_random_walk.py:109, in RandomWalk.simulate_one(self, max_iter, seed, successive_hits)
    106 sim, cnt = [ix], -1
    108 for _ in range(max_iter):
--> 109     ix = self._sample(ix, rs=rs)
    110     sim.append(ix)
    111     cnt = (cnt + 1) if self._should_stop(ix) else -1

File /g/furlong/procaccia/miniconda3/envs/cellrank/lib/python3.11/site-packages/cellrank/kernels/utils/_random_walk.py:342, in RandomWalk._sample(self, ix, rs)
    339 def _sample(self, ix: int, *, rs: np.random.RandomState) -> int:
    340     return rs.choice(
    341         self._ixs,
--> 342         p=self._tmat[ix].A.squeeze() if self._is_sparse else self._tmat[ix],
    343     )

AttributeError: 'csr_matrix' object has no attribute 'A'

Versions:

cellrank==2.0.4 scanpy==1.10.2 anndata==0.10.8 numpy==1.26.4 numba==0.60.0 scipy==1.14.0 pandas==2.2.2 pygpcca==1.0.4 scikit-learn==1.1.3 statsmodels==0.14.2 scvelo==0.0.0 pygam==0.9.1 matplotlib==3.9.1 seaborn==0.13.2

Marius1311 commented 3 months ago

Thanks @michalk8! This should be fixed in the linked PR @simozhou