theislab / moscot

Multi-omic single-cell optimal transport tools
https://moscot-tools.org
BSD 3-Clause "New" or "Revised" License

Online method throws error #55

Closed MUCDK closed 2 years ago

MUCDK commented 2 years ago

After installing the requirements as instructed I get the following problem:

The Regularized.fit() method fails when "online" is set to True in the geometry object.

import anndata  # needed for anndata.read below
import jax.numpy as jnp
import scanpy as sc
from ott.geometry.pointcloud import PointCloud
from moscot._solver import Regularized

adata = anndata.read("/home/icb/dominik.klein/git_repos/data/adatas/adata_tedsim_8192.h5ad")
obs_var_time = "depth"
adata_source = adata[adata.obs[obs_var_time] == 11]
adata_target = adata[adata.obs[obs_var_time] == 12]

sc.pp.pca(adata_source)
sc.pp.pca(adata_target)

pointcloud_offline = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=False)
pointcloud_online = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=True)

moscot_solver = Regularized(epsilon=0.2)

moscot_solver.fit(pointcloud_offline) # works
moscot_solver.fit(pointcloud_online) # does not work

Error message:

`TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_27778/1474808108.py in <module>
      1 moscot_solver = Regularized(epsilon=0.2)
      2 
----> 3 moscot_solver.fit(pointcloud_online)

/mnt/home/icb/dominik.klein/git_repos/moscot/moscot/_solver.py in fit(self, geom, a, b, **kwargs)
     96         """
     97         geom = self._prepare_geom(geom, **kwargs)
---> 98         self._transport = Transport(geom, a=a, b=b, **self._kwargs)
     99         self._check_marginals(a, b)
    100 

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/tools/transport.py in __init__(self, a, b, *args, **kwargs)
     66       self.geom = pointcloud.PointCloud(*args, **pc_kw)
     67 
---> 68     num_a, num_b = self.geom.shape
     69     self.a = jnp.ones((num_a,)) / num_a if a is None else a
     70     self.b = jnp.ones((num_b,)) / num_b if b is None else b

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    138   @property
    139   def shape(self):
--> 140     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    141     if mat is not None:
    142       return mat.shape

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    110       # If no epsilon was passed on to the geometry, then assume it is one by
    111       # default.
--> 112       cost = -jnp.log(self._kernel_matrix)
    113       return cost if self._epsilon_init is None else self.epsilon * cost
    114     return self._cost_matrix

    [... skipping hidden 15 frame]

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x)
    690 def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
    691   if promote_to_inexact:
--> 692     fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    693   else:
    694     fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _promote_args_inexact(fun_name, *args)
    600 
    601   Promotes non-inexact types to an inexact type."""
--> 602   _check_arraylike(fun_name, *args)
    603   _check_no_float0s(fun_name, *args)
    604   return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    576                     if not _arraylike(arg))
    577     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 578     raise TypeError(msg.format(fun_name, type(arg), pos))
    579 
    580 def _check_no_float0s(fun_name, *args):

TypeError: log requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.`

Marius1311 commented 2 years ago

on hold.

Marius1311 commented 2 years ago

@michalk8 agreed to look into this a bit.

michalk8 commented 2 years ago

The problem is here: https://github.com/theislab/moscot/blob/dev/moscot/_solver.py#L45. geom.cost_matrix is None when online=True. The Transport object then tries to access self.geom.shape, which throws the above error. I'm not sure whether allowing neither a cost nor a kernel matrix to be passed (or passing both and later ignoring the kernel matrix) is by design. I can't find any use case for it, except for copying eps from another temporary geometry to the current one, which can be useful and is what I should've done in the above example. Imho, not really a bug, but a feature. Minimal reproducible example:

from ott.geometry.geometry import Geometry
Geometry().shape
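
The failure mode can also be sketched without OTT or JAX. The `ToyGeometry` class below is a hypothetical stand-in, not OTT's actual `Geometry`; it only mirrors the property chain visible in the traceback, where `shape` reads `cost_matrix`, which in turn falls back to the negative log of the kernel matrix. With neither matrix set, that fallback receives `None` and raises a `TypeError`.

```python
import math


class ToyGeometry:
    """Hypothetical stand-in for a geometry; NOT the real ott Geometry class."""

    def __init__(self, cost_matrix=None, kernel_matrix=None):
        self._cost_matrix = cost_matrix
        self._kernel_matrix = kernel_matrix

    @property
    def cost_matrix(self):
        if self._cost_matrix is None:
            # Mirrors the traceback: the cost is derived from the kernel,
            # which fails with a TypeError when the kernel is None as well.
            return [[-math.log(k) for k in row] for row in self._kernel_matrix]
        return self._cost_matrix

    @property
    def shape(self):
        # The shape is only known once a matrix has been materialized.
        mat = self.cost_matrix
        return (len(mat), len(mat[0]))


def try_shape(geom):
    """Return the shape, or a short description of the TypeError raised."""
    try:
        return geom.shape
    except TypeError as err:
        return f"TypeError: {err}"


print(try_shape(ToyGeometry(cost_matrix=[[0.0, 1.0], [1.0, 0.0]])))
print(try_shape(ToyGeometry()))  # neither matrix is set
```

The first call returns a shape because a cost matrix exists; the second fails in the same place as the real traceback, inside the cost-from-kernel fallback.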

Marius1311 commented 2 years ago

Ok, should we

Marius1311 commented 2 years ago

@MUCDK, can we move this issue to OTT and close it here?

Marius1311 commented 2 years ago

Or just close if it's no longer relevant please.

MUCDK commented 2 years ago

I think we can close it for now. We could think about catching the error ourselves.
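
A minimal sketch of what catching the error ourselves might look like, assuming the solver only needs `geom.shape` before building the transport. All names here (`GeometryShapeError`, `safe_shape`, `DenseGeometry`, `BrokenGeometry`) are hypothetical and are not part of moscot or OTT; the stand-in classes only imitate a geometry whose shape is available and one whose shape lookup raises the `TypeError` from the traceback.

```python
class GeometryShapeError(ValueError):
    """Hypothetical error type used only in this sketch."""


def safe_shape(geom):
    """Return geom.shape, re-raising an opaque TypeError with a clearer message."""
    try:
        return geom.shape
    except TypeError as err:
        raise GeometryShapeError(
            "Could not determine the shape of the geometry. No cost or kernel "
            "matrix seems to be available, e.g. because online=True was set."
        ) from err


class DenseGeometry:
    """Stand-in for a geometry with a materialized cost matrix."""

    def __init__(self, n, m):
        self._shape = (n, m)

    @property
    def shape(self):
        return self._shape


class BrokenGeometry:
    """Stand-in for a geometry whose shape lookup fails as in the traceback."""

    @property
    def shape(self):
        raise TypeError("log requires ndarray or scalar arguments, got NoneType")


def describe(geom):
    """Small usage helper: report the shape or the friendlier error."""
    try:
        return f"shape={safe_shape(geom)}"
    except GeometryShapeError as err:
        return f"error: {err}"


print(describe(DenseGeometry(3, 4)))
print(describe(BrokenGeometry()))
```

Chaining with `from err` keeps the original traceback available for debugging while the user sees a message that points at the likely cause.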