theislab / moscot

Multi-omic single-cell optimal transport tools
https://moscot-tools.org
BSD 3-Clause "New" or "Revised" License

Pytrees for our OTT wrappers #124

Closed · michalk8 closed 2 years ago

michalk8 commented 2 years ago

Related to #117 and to #122 (if we decide to jit).
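
For anyone catching up on the pytree part, here is a hypothetical sketch of the mechanism (not moscot's actual implementation; the real TaggedArray has different fields): registering a TaggedArray-like container as a JAX pytree is what would let it pass through jax.jit and jax.grad transparently:

from typing import Tuple

import jax
import jax.numpy as jnp

class TaggedArray:  # hypothetical stand-in for moscot's TaggedArray
    def __init__(self, data: jnp.ndarray, tag: str):
        self.data = data  # array leaf, traced by JAX
        self.tag = tag  # static metadata, kept out of tracing

    def _tree_flatten(self) -> Tuple[Tuple[jnp.ndarray, ...], str]:
        return (self.data,), self.tag  # (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data: str, children: Tuple[jnp.ndarray, ...]) -> "TaggedArray":
        return cls(children[0], aux_data)

jax.tree_util.register_pytree_node(
    TaggedArray, TaggedArray._tree_flatten, TaggedArray._tree_unflatten
)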

michalk8 commented 2 years ago

I'd now argue against this, pinging @MUCDK:

On the one hand, having our solvers jittable would make implementing the functions that require gradients easier (and would also allow pre-jitting), since we could do:

from typing import Any, Optional

import jax

@jax.jit  # requires TaggedArray to be a registered pytree so jit can trace it
def run(x: TaggedArray, y: Optional[TaggedArray], **kwargs: Any) -> float:
    solver = SinkhornSolver()
    return solver(x, y, **kwargs).cost
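
(A note that is my assumption rather than anything stated above: for run to jit cleanly, TaggedArray would have to be a registered pytree as sketched earlier, and any non-array values in **kwargs would need to be hashable and marked static, e.g. via jax.jit's static_argnames, since jit only traces array-like leaves.)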

On the other hand, this is reinventing the wheel, since it is exactly what OTT does, just with our types... With "composition over inheritance" we have reduced flexibility (in terms of how we expose the arguments w.r.t. which we want to differentiate), as well as some code redundancy:


from typing import Any, Optional

import jax


class SinkhornSolver:  # our solver

    # could also be abstract; the real problem is flexibility: which arguments
    # to expose for the gradient calculation
    def _value_and_grad(self, x: TaggedArray, y: Optional[TaggedArray], jit: bool = False, **kwargs: Any) -> ...:
        # essentially a watered-down version of __call__
        data = self._prepare_input(x, y, **kwargs)
        grad_fn = jax.value_and_grad(lambda data: self._solver(data))  # defer to OTT
        if jit:
            grad_fn = jax.jit(grad_fn)
        return grad_fn(data)
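
As a standalone illustration of the mechanism this sketch leans on (my own toy example with a plain dict, not moscot's types): jax.value_and_grad of a function taking a pytree returns gradients with the same pytree structure, one per array leaf, which is exactly why deciding which arguments to expose for differentiation matters:

import jax
import jax.numpy as jnp

def cost(data) -> jnp.ndarray:
    # stand-in for the prepared solver input and its regularized OT cost
    return jnp.sum(data["x"] ** 2) + jnp.sum(data["y"] ** 2)

val, grads = jax.value_and_grad(cost)({"x": jnp.ones(3), "y": jnp.ones(2)})
# grads == {"x": ..., "y": ...}: by default, every array leaf gets a gradient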

The __call__ abstraction might be necessary in the future because of barycenters, but it would be great to have an initial application from @MUCDK, since this will determine the API. Happy to also discuss with @zoepiran @giovp @Marius1311.

giovp commented 2 years ago

I don't know enough about pytrees to comment, although I'll try to do some reading to catch up. From a high-level perspective, would this solution work out of the box with the current solvers API, @michalk8?


class SinkhornSolver:  # our solver

    # could also be abstract; the real problem is flexibility: which arguments
    # to expose for the gradient calculation
    def _value_and_grad(self, x: TaggedArray, y: Optional[TaggedArray], jit: bool = False, **kwargs: Any) -> ...:
        # essentially a watered-down version of __call__
        data = self._prepare_input(x, y, **kwargs)
        grad_fn = jax.value_and_grad(lambda data: self._solver(data))  # defer to OTT
        if jit:
            grad_fn = jax.jit(grad_fn)
        return grad_fn(data)

If yes, I think this is fine for us; I understand the redundancy, but it's not too bad imho. I think it's important to benchmark to what extent getting the gradients is slow (since we're not jitting), @MUCDK.
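
A minimal version of that benchmark could look like this (my sketch, pure JAX, deliberately not tied to moscot's or OTT's APIs; reg_cost is a made-up stand-in for a solver's cost):

import time

import jax
import jax.numpy as jnp

def reg_cost(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(x * jnp.log(x + 1e-9))  # made-up stand-in cost

x = jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (100_000,)))
for label, fn in (("eager", jax.value_and_grad(reg_cost)),
                  ("jitted", jax.jit(jax.value_and_grad(reg_cost)))):
    fn(x)[0].block_until_ready()  # warm-up (includes compilation when jitted)
    start = time.perf_counter()
    for _ in range(100):
        fn(x)[0].block_until_ready()
    print(label, (time.perf_counter() - start) / 100)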

MUCDK commented 2 years ago

@giovp this doesn't work as of now, since we have to call value_and_grad on jax.numpy arrays. Also, taking the gradient does require more time (not benchmarked yet, but observed empirically), so I would only have that as an optional feature.
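
Which circles back to the issue title: jax.value_and_grad accepts any pytree whose leaves are JAX arrays, not only raw jax.numpy arrays, so registering TaggedArray as a pytree (as in the hypothetical sketch at the top of the thread) is precisely what would lift this restriction:

import jax
import jax.numpy as jnp

# assumes the hypothetical TaggedArray pytree registration from the first sketch
ta = TaggedArray(jnp.ones(3), tag="point_cloud")
val, grad = jax.value_and_grad(lambda t: jnp.sum(t.data ** 2))(ta)
# grad is itself a TaggedArray, with the gradients stored in grad.data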