DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Fix complex-valued `t_span` for complex-valued ODEs in odeint #179

Closed: gautierronan closed this pull request 1 year ago

gautierronan commented 1 year ago

Following up on my PR https://github.com/DiffEqML/torchdyn/pull/178 and on this issue https://github.com/DiffEqML/torchdyn/issues/177.

In the current state of `odeint`, if `x` is initially complex-valued, then the time values `t` passed to `f(t, x)` at every time step are complex-valued as well. They should be real (float) values instead. Otherwise, problems arise whenever `f(t, x)` requires real-valued times (e.g. `f(t, x) = torch.erf(t) * x`).
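A minimal sketch of the intended behaviour. It uses a hypothetical fixed-step Euler loop, `euler_odeint`, rather than torchdyn's own `odeint`. `t_span` stays real while the state stays complex. The PR's example field `torch.erf(t) * x` then works, because PyTorch promotes a real-times-complex product to a complex result. As far as we know, `torch.erf` does not accept complex tensors, so a complex `t` would break this field.

```python
import torch

def f(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Example vector field from the PR: torch.erf needs a real-valued t.
    return torch.erf(t) * x

def euler_odeint(f, x0: torch.Tensor, t_span: torch.Tensor) -> torch.Tensor:
    # Hypothetical fixed-step explicit Euler solver (not torchdyn's API).
    # t_span stays real even when x0 is complex; real * complex promotes
    # to complex, so the state keeps its complex dtype.
    xs = [x0]
    x = x0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        x = x + (t1 - t0) * f(t0, x)
        xs.append(x)
    return torch.stack(xs)

x0 = torch.tensor([1.0 + 1.0j, 0.5 - 2.0j], dtype=torch.complex64)
t_span = torch.linspace(0.0, 1.0, 11)  # float32, same precision as complex64
traj = euler_odeint(f, x0, t_span)
print(traj.dtype, traj.shape)  # torch.complex64 torch.Size([11, 2])
```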

This PR fixes this by changing the `solver.tableau` definitions. The `c` constants in the tableaux are initialized with a real float dtype of the same precision as the (complex or float) dtype of `x`, obtained with `t_dtype = getattr(torch, torch.finfo(x.dtype).dtype)`. If `x` is float-valued, then `t_dtype = x.dtype` and things work as usual. `sync_device_dtype` also had to be changed to distinguish `t.dtype` from `x.dtype`.
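The dtype mapping described above can be sketched as follows. Only the `getattr(torch, torch.finfo(x.dtype).dtype)` expression comes from the PR. `rk4_c_nodes` and `sync_t_with_x` are hypothetical helpers, not torchdyn's actual tableau constructors or its `sync_device_dtype`. The classic RK4 nodes are used purely as an illustration.

```python
import torch

def real_dtype_of(x_dtype: torch.dtype) -> torch.dtype:
    # torch.finfo(...).dtype is a string such as "float32"; for a complex
    # dtype it names the real dtype of the same precision (complex64 -> "float32").
    return getattr(torch, torch.finfo(x_dtype).dtype)

def rk4_c_nodes(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: classic RK4 nodes c, built in the real dtype
    # matching x's precision and placed on x's device.
    return torch.tensor([0.0, 0.5, 0.5, 1.0], dtype=real_dtype_of(x.dtype), device=x.device)

def sync_t_with_x(x: torch.Tensor, t_span: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the sync step: move t_span to x's device,
    # but cast it to the *real* dtype of x rather than to x.dtype itself.
    return t_span.to(device=x.device, dtype=real_dtype_of(x.dtype))

x = torch.zeros(3, dtype=torch.complex128)
t_span = sync_t_with_x(x, torch.linspace(0.0, 1.0, 5))
dt = t_span[1] - t_span[0]
stage_times = t_span[0] + rk4_c_nodes(x) * dt  # RK stage times stay real
print(t_span.dtype, stage_times.dtype)  # torch.float64 torch.float64
```

For a float-valued `x` the mapping is the identity, so the real-valued case behaves exactly as before.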

Sorry for having missed this in the initial PR!

codecov-commenter commented 1 year ago

Codecov Report

Base: 62.53% // Head: 62.63% // Increases project coverage by +0.09% :tada:

Coverage data is based on head (c068b2b) compared to base (201026a). Patch coverage: 90.00% of modified lines in pull request are covered.


Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #179      +/-   ##
==========================================
+ Coverage   62.53%   62.63%   +0.09%
==========================================
  Files          27       27
  Lines        1879     1884       +5
==========================================
+ Hits         1175     1180       +5
  Misses        704      704
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| `torchdyn/numerics/solvers/ode.py` | `53.12% <50.00%> (+0.21%)` | :arrow_up: |
| `torchdyn/numerics/sensitivity.py` | `97.95% <100.00%> (ø)` | |
| `torchdyn/numerics/solvers/_constants.py` | `100.00% <100.00%> (ø)` | |
| `torchdyn/numerics/solvers/templates.py` | `75.00% <100.00%> (+0.58%)` | :arrow_up: |


gautierronan commented 1 year ago

Actually, thinking more about this, I'm not sure what the best practice is. Should `t_span` stay complex-valued so that every object in the model has the same dtype? Or should it be real-valued for consistency?

I'm unsure of what's best... I'll let the maintainers decide!

Zymrael commented 1 year ago

Thanks for contributing! It makes sense to keep `t` real-valued. We don't currently have any use case in which it would be convenient for `x` and `t` to share the same dtype.