ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0

GW solver and Sinkhorn solver handle initializer kwargs inconsistently #591

Open selmanozleyen opened 3 weeks ago

selmanozleyen commented 3 weeks ago

Hi @michalk8

While `Sinkhorn` and `LRGromovWasserstein` accept `kwargs_init` and store it, `GromovWasserstein` does not.

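Here is the motivation. In moscot, for example, we currently have to build the two solvers differently, as in the snippet below: the low-rank solver receives the initializer kwargs through `kwargs_init`, while for the full-rank solver we have to bind them into the initializer ourselves with `functools.partial`. It would be good to have a consistent constructor.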

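```python
import functools
import logging

from ott.initializers.quadratic import initializers as quad_initializers
from ott.solvers.linear import sinkhorn
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

logger = logging.getLogger(__name__)

# Excerpt from moscot's solver setup: `self`, `rank`, `initializer`,
# `initializer_kwargs`, `linear_solver_kwargs` and `kwargs` come from the enclosing method.
if rank > -1:
    # Low-rank GW: initializer kwargs can be forwarded via `kwargs_init`.
    kwargs.setdefault("gamma", 10)
    kwargs.setdefault("gamma_rescale", True)
    eps = kwargs.get("epsilon")
    if eps is not None and eps > 0.0:
        logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
    initializer = "rank2" if initializer is None else initializer
    self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
        rank=rank,
        initializer=initializer,
        kwargs_init=initializer_kwargs,
        **kwargs,
    )
else:
    # Full-rank GW: there is no `kwargs_init`, so the kwargs must be bound into the initializer by hand.
    linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
    if initializer is None:
        initializer = quad_initializers.QuadraticInitializer()
    if isinstance(initializer, str):
        raise ValueError(
            "Expected `initializer` to be an instance of `ott.initializers.quadratic.BaseQuadraticInitializer`, "
            f"found `{initializer}`."
        )
    initializer = functools.partial(initializer, **initializer_kwargs)
    self._solver = gromov_wasserstein.GromovWasserstein(
        linear_solver=linear_solver,
        initializer=initializer,
        **kwargs,
    )
```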
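A minimal sketch of the consistent constructor this issue asks for, written in plain Python so it runs without ott installed. The classes `BaseSolver`, `ToyGW`, `ToyLRGW`, `Initializer`, `Rank2Initializer` and the helper `build_solver` are hypothetical stand-ins, not ott's API; the only convention taken from the issue is that every solver accepts `initializer` plus `kwargs_init` and stores both. Under that assumption, the two branches of the moscot snippet forward initializer arguments identically, and the type check and `functools.partial` step move out of the caller.

```python
from typing import Any, Dict, Optional, Union


# Hypothetical stand-ins; these are NOT ott's real initializer or solver classes.
class Initializer:
    """Toy initializer that records the keyword arguments it was built with."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


class Rank2Initializer(Initializer):
    """Toy stand-in for a low-rank initializer selected by the name "rank2"."""


_REGISTRY = {"default": Initializer, "rank2": Rank2Initializer}


class BaseSolver:
    """Shared constructor: every solver accepts `initializer` and `kwargs_init` and stores both."""

    default_initializer = "default"

    def __init__(
        self,
        initializer: Union[str, type, None] = None,
        kwargs_init: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        self.initializer = initializer if initializer is not None else self.default_initializer
        self.kwargs_init = dict(kwargs_init or {})
        self.kwargs = kwargs  # remaining solver options, e.g. epsilon

    def create_initializer(self) -> Initializer:
        # Resolve a registered name or a class, then apply the stored kwargs.
        if isinstance(self.initializer, str):
            cls = _REGISTRY[self.initializer]
        else:
            cls = self.initializer
        return cls(**self.kwargs_init)


class ToyGW(BaseSolver):
    def __init__(self, linear_solver: Any = None, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.linear_solver = linear_solver


class ToyLRGW(BaseSolver):
    default_initializer = "rank2"

    def __init__(self, rank: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.rank = rank


def build_solver(
    rank: int = -1,
    initializer: Union[str, type, None] = None,
    initializer_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> BaseSolver:
    # With a shared constructor, both branches pass initializer arguments the same way.
    common = dict(initializer=initializer, kwargs_init=initializer_kwargs, **kwargs)
    if rank > -1:
        return ToyLRGW(rank=rank, **common)
    return ToyGW(linear_solver="sinkhorn", **common)
```

Storing `kwargs_init` on the solver, instead of binding it into the initializer up front, lets the initializer be given as either a name or a class and be resolved later; in this sketch that is what allows the low-rank branch to default to `"rank2"` while the full-rank branch uses a different default.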

ping @MUCDK