Closed — michalk8 closed this issue 2 years ago
I'd now argue against this, pinging @MUCDK:

- making `__call__` abstract would lose some of its benefits (like the shared code for checking input/output)
- we could define an alternative branch in the hierarchy that would be more JIT-friendly, but this would require a rewrite of the full OTT backend, and I'm not sure whether jitting can be achieved with the current code (like the contextualized solves)
- `TaggedArray` would need to be JAX-compatible; if we were to deploy with the OTT backend as the default (i.e. the user has no choice), this would be less of an issue

On the other hand, having our solvers jittable would make implementing the functions requiring gradients easier (+ also pre-jitting), since we could do:
```python
@jax.jit
def run(x: TaggedArray, y: Optional[TaggedArray], **kwargs: Any) -> float:
    solver = SinkhornSolver()
    return solver(x, y, **kwargs).cost
```
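For the snippet above to work, `TaggedArray` would have to be registered as a JAX pytree so it can cross the `jax.jit` boundary. A minimal sketch of what that registration could look like, assuming a hypothetical `TaggedArray` wrapper with a `data` array and a `tag` string (both names are stand-ins, not the actual moscot implementation):

```python
from typing import Any, Tuple

import jax
import jax.numpy as jnp


class TaggedArray:  # hypothetical stand-in for the real TaggedArray
    def __init__(self, data: jnp.ndarray, tag: str = "point_cloud"):
        self.data = data
        self.tag = tag


def _flatten(arr: TaggedArray) -> Tuple[Tuple[Any, ...], str]:
    # children are traced/differentiated by JAX; the tag is static aux data
    return (arr.data,), arr.tag


def _unflatten(tag: str, children: Tuple[Any, ...]) -> TaggedArray:
    return TaggedArray(children[0], tag)


jax.tree_util.register_pytree_node(TaggedArray, _flatten, _unflatten)


@jax.jit
def total(x: TaggedArray) -> jnp.ndarray:
    # once registered, a TaggedArray can be passed through jit directly
    return jnp.sum(x.data)
```

e.g. `total(TaggedArray(jnp.arange(4.0)))` returns `6.0`.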
On the other hand, this is reinventing the wheel, since this is exactly what OTT does, just using our types... In the case of "composition over inheritance", we'd have reduced flexibility (in terms of how we expose the arguments w.r.t. which we want to differentiate), as well as some redundancy in terms of code:
```python
class SinkhornSolver:  # our solver
    # could also be abstract; the problem really is flexibility, i.e. which
    # arguments to expose for the gradient computation
    def _value_and_grad(self, x: TaggedArray, y: Optional[TaggedArray], jit: bool = False, **kwargs: Any) -> ...:
        # essentially a watered-down version of __call__
        data = self._prepare_input(x, y, **kwargs)
        grad_fn = jax.value_and_grad(lambda data: self._solver(data))  # defer to OTT
        if jit:
            grad_fn = jax.jit(grad_fn)
        return grad_fn(data)
```
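To illustrate the `value_and_grad` + optional-`jit` pattern in isolation, here is a self-contained toy version using plain `jax.numpy` arrays and a quadratic cost as a stand-in for the OTT solve (all names here are hypothetical, not moscot's API):

```python
import jax
import jax.numpy as jnp


def toy_cost(data: jnp.ndarray) -> jnp.ndarray:
    # stand-in for self._solver(data); any differentiable scalar cost works
    return jnp.sum(data ** 2)


def value_and_grad(data: jnp.ndarray, jit: bool = False):
    grad_fn = jax.value_and_grad(toy_cost)
    if jit:
        # optional pre-compilation, mirroring the jit flag in the sketch above
        grad_fn = jax.jit(grad_fn)
    return grad_fn(data)
```

e.g. `value_and_grad(jnp.array([1.0, 2.0]))` returns the value `5.0` and the gradient `[2.0, 4.0]`.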
The `__call__` abstraction might be necessary in the future because of barycenters, but it would be great to have some initial application from @MUCDK, since this will determine the API.
Happy to discuss also with @zoepiran @giovp @Marius1311.
I don't know enough about pytrees to comment, although I'll try to do some reading to catch up. From a high-level perspective, would this solution work out of the box with the current solvers API @michalk8?
```python
class SinkhornSolver:  # our solver
    # could also be abstract; the problem really is flexibility, i.e. which
    # arguments to expose for the gradient computation
    def _value_and_grad(self, x: TaggedArray, y: Optional[TaggedArray], jit: bool = False, **kwargs: Any) -> ...:
        # essentially a watered-down version of __call__
        data = self._prepare_input(x, y, **kwargs)
        grad_fn = jax.value_and_grad(lambda data: self._solver(data))  # defer to OTT
        if jit:
            grad_fn = jax.jit(grad_fn)
        return grad_fn(data)
```
If yes, I think this is fine for us; I understand the redundancy, but it's not too bad imho. I think it's important to benchmark to what extent getting the gradients is slow (since it's without jitting) @MUCDK
@giovp this doesn't work as of now, since we have to call `value_and_grad` on `jax.numpy` arrays. Also, taking the gradient does require more time (not benchmarked yet, but empirically), so I would only have that as optional.
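To make the limitation concrete: `jax.value_and_grad` differentiates w.r.t. JAX arrays (or registered pytrees), so without registering `TaggedArray`, one would have to unwrap the raw arrays first. A hedged sketch of that workaround, with a toy pairwise cost and `SimpleNamespace` standing in for `TaggedArray` (all names hypothetical):

```python
from types import SimpleNamespace

import jax
import jax.numpy as jnp


def cost_from_arrays(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # stand-in for the actual OTT solve on the unwrapped data
    return jnp.sum((x[:, None] - y[None, :]) ** 2)


def value_and_grad_unwrapped(tagged_x, tagged_y):
    # unwrap .data so value_and_grad only ever sees plain jax.numpy arrays
    return jax.value_and_grad(cost_from_arrays, argnums=(0, 1))(tagged_x.data, tagged_y.data)
```

e.g. for `tagged_x.data = [0.0, 1.0]` and `tagged_y.data = [1.0]`, this returns the cost `1.0` and the gradients `([-2.0, 0.0], [2.0])`.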
related to #117 and to #122 (if we decide to jit)