ott-jax / ott

Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
https://ott-jax.readthedocs.io
Apache License 2.0
524 stars, 80 forks

Refactor/Fix: WassersteinSolver constructor now throws TypeError when an unrecognized argument is given #579

Closed: selmanozleyen closed this 1 month ago

selmanozleyen commented 1 month ago

Hi,

A user can pass an argument with a typo, or misunderstand what an argument does, and the solver class would silently accept it without the user noticing. To prevent such cases, I changed the constructor so that it raises a TypeError when it receives an unrecognized argument. I also added tests that assert these errors are raised properly.
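The change can be illustrated with a minimal sketch, assuming a constructor that collects leftover keyword arguments in **kwargs. The class and parameter names (SolverSketch, epsilon, max_iterations) are hypothetical stand-ins, not ott's actual WassersteinSolver code.

```python
from typing import Any, Optional


class SolverSketch:
    """Hypothetical solver base class; not ott's actual implementation."""

    def __init__(
        self,
        epsilon: Optional[float] = None,
        max_iterations: int = 50,
        **kwargs: Any,
    ):
        # Anything still in kwargs was not consumed by this class (or by a
        # subclass), so it is most likely a typo: fail loudly instead of
        # silently ignoring it.
        if kwargs:
            unknown = ", ".join(sorted(kwargs))
            raise TypeError(
                f"{type(self).__name__} got unrecognized argument(s): {unknown}"
            )
        self.epsilon = epsilon
        self.max_iterations = max_iterations


solver = SolverSketch(epsilon=1e-2, max_iterations=100)  # accepted

try:
    SolverSketch(epsilon=1e-2, max_iteration=100)  # typo: "max_iteration"
except TypeError as err:
    print(err)  # -> SolverSketch got unrecognized argument(s): max_iteration
```

Raising TypeError mirrors what Python itself does when a function without **kwargs receives an unexpected keyword. Dropping **kwargs entirely would give the same check for free; keeping it lets subclasses forward their leftover arguments to super().__init__().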

Note: I am not sure why the linting fails; tox -e lint-code passes locally for me. I also modified the caching in the CI, because it did not work on my PR for some reason.

Related: https://github.com/theislab/moscot/pull/748

ping: @MUCDK

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 57.14286% with 3 lines in your changes missing coverage. Please review.

Project coverage is 87.81%. Comparing base (aa33bd9) to head (88bde47). Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/ott/solvers/was_solver.py | 57.14% | 2 Missing and 1 partial :warning: |
Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##             main     #579      +/-   ##
==========================================
- Coverage   87.83%   87.81%   -0.03%
==========================================
  Files          73       73
  Lines        7826     7845      +19
  Branches     1127     1133       +6
==========================================
+ Hits         6874     6889      +15
- Misses        799      801       +2
- Partials      153      155       +2
```

| Files with missing lines | Coverage Δ | |
|---|---|---|
| src/ott/solvers/was_solver.py | 79.59% <57.14%> (-3.75%) | :arrow_down: |

... and 4 files with indirect coverage changes

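The headline percentages in the report can be reproduced from the line counts in the coverage diff. This is a sketch under two assumptions: that Codecov counts partially covered lines as not covered (consistent with Lines = Hits + Misses + Partials in the diff), and that project figures are rounded down to two decimals, which we believe is Codecov's default. The split of the patch into 4 hit, 2 missing and 1 partial line is inferred from the "2 Missing and 1 partial" row together with the 57.14286% figure; the helper names are illustrative.

```python
import math


def coverage_pct(hits: int, misses: int, partials: int) -> float:
    """Percentage of tracked lines fully hit; partial lines count as not covered."""
    return 100.0 * hits / (hits + misses + partials)


def floor2(x: float) -> float:
    """Round down to two decimals (assumed to match Codecov's default rounding)."""
    return math.floor(x * 100) / 100


base = coverage_pct(hits=6874, misses=799, partials=153)  # main (aa33bd9), 7826 lines
head = coverage_pct(hits=6889, misses=801, partials=155)  # PR head (88bde47), 7845 lines
print(floor2(base), floor2(head), floor2(head - base))  # -> 87.83 87.81 -0.03

# Patch: inferred 4 hit + 2 missing + 1 partial = 7 changed lines.
patch = coverage_pct(hits=4, misses=2, partials=1)
print(round(patch, 5))  # -> 57.14286
```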
marcocuturi commented 1 month ago

Thanks @selmanozleyen for the PR! I will defer to @michalk8 on this, but it seems that if we implement this for this particular solver, we would need to implement it for all solvers, no? What was the use case that revealed the problem?

selmanozleyen commented 1 month ago

For linear solvers there is no need, since their base class, Sinkhorn, does not take kwargs, so Python already raises a TypeError for unknown arguments. Since WassersteinSolver now rejects unrecognized kwargs, all of its child classes do as well, because they all pass their remaining kwargs to super().__init__().
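The inheritance argument can be sketched with hypothetical classes (NoKwargsSolver, CheckingBase, ChildSolver and the helper raises_type_error are illustrative names, not ott classes). A constructor without **kwargs gets Python's built-in TypeError for unknown keywords, and a subclass that forwards its leftover kwargs to super().__init__() inherits the base-class check.

```python
from typing import Any


class NoKwargsSolver:
    """Hypothetical solver whose __init__ takes no **kwargs (the Sinkhorn case)."""

    def __init__(self, threshold: float = 1e-3):
        self.threshold = threshold


class CheckingBase:
    """Hypothetical base class that rejects leftover keyword arguments."""

    def __init__(self, epsilon: float = 1e-2, **kwargs: Any):
        if kwargs:
            raise TypeError(f"unrecognized argument(s): {', '.join(sorted(kwargs))}")
        self.epsilon = epsilon


class ChildSolver(CheckingBase):
    """Hypothetical subclass: consumes its own argument, forwards the rest."""

    def __init__(self, alpha: float = 0.5, **kwargs: Any):
        self.alpha = alpha
        super().__init__(**kwargs)  # leftovers reach CheckingBase's check


def raises_type_error(cls: type, **kwargs: Any) -> bool:
    """Return True if constructing cls(**kwargs) raises TypeError."""
    try:
        cls(**kwargs)
    except TypeError:
        return True
    return False


print(raises_type_error(NoKwargsSolver, treshold=1e-4))          # -> True (Python's own check)
print(raises_type_error(ChildSolver, alhpa=0.1))                 # -> True (forwarded to the base)
print(raises_type_error(ChildSolver, alpha=0.1, epsilon=1e-3))   # -> False (all recognized)
```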

In moscot we do not want to silently ignore unrecognized arguments: there are many arguments, and a typo can lead to well-hidden bugs.

Here is the PR for it: https://github.com/theislab/moscot/pull/748

We use many (if not all) of the solvers, and from my tests this PR should be enough to cover the constructors of both linear and quadratic solvers. I am not sure about other methods in ott-jax, such as solve, though.

selmanozleyen commented 1 month ago

@michalk8, since the interface is going to change, I think it would be better if you did it. I have already resolved the other pre-commit and formatting issues you mentioned.

michalk8 commented 1 month ago

OK, thanks! I will close this PR then and open a new one tomorrow.

selmanozleyen commented 1 month ago

Hi @michalk8, just wanted to remind you about this. I think many test cases might have to change, since the API changes as well, so maybe I can help a bit.