Open ThummeTo opened 1 year ago
I think this is the kind of thing we just want to be working on getting Enzyme ready for.
Why not `sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))` instead of `sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())` here?
Note that without `true`, this out-of-place form

```julia
function fx(dx, x, p, t)
    dx[:] = re(p)(x)
end
```

will be slower than `fx(x, p, t) = re(p)(x)`, of course, because of the scalarizing.
Also, one major improvement is to use Lux instead, or, for small neural networks, SimpleChains (with static arrays).
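For the Lux suggestion, a minimal sketch of what the RHS could look like (the layer sizes and names here are illustrative, not taken from the MWE):

```julia
using Lux, Random

# Small illustrative MLP; Lux keeps parameters explicit, which avoids the
# Flux.destructure / re(p) round-trip entirely.
model = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

# The ODE right-hand side closes over the (immutable) layer state `st`
# and takes the parameters `p` directly:
fx(x, p, t) = first(model(x, p, st))
```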
Thanks for the reply! Yep, `ReverseDiffVJP(true)` is a good point; to be honest, I wasn't sure whether it is allowed here, because of the "no-branching" requirement for pre-compilation of tapes.
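For reference, the branching concern can be illustrated with plain ReverseDiff (toy function chosen for illustration): a compiled tape replays the operations recorded for the example input, so value-dependent branches are baked in:

```julia
using ReverseDiff

# Control flow depends on the input value:
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

# Compiling the tape records the branch taken for THIS input (x[1] > 0) ...
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, [1.0, 2.0]))

# ... so replaying on an input that should take the other branch silently
# reuses the recorded sum(abs2, x) path and yields a wrong gradient:
grad = ReverseDiff.gradient!(zeros(2), tape, [-1.0, 2.0])
```

So the compiled-tape speedup of `ReverseDiffVJP(true)` is only safe when the RHS contains no value-dependent branches, which is exactly the case for a plain neural network.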
Migration to Lux is also on the to-do-list :-)
And I am super-curious what progress Enzyme is making (after the big steps in the last months/weeks). I will keep checking for that.
Very good news: DtO works in the current release(s) if you specify a solver by hand. Sensitivities are determined correctly and without numerical instabilities/NaNs. Thank you very much, @ChrisRackauckas and @frankschae! However, the provided MWE as-is (without a solver specified) still fails because of the linked DiffEqBase issue.
Current progress:
- Single event at the same time instant:
- Multiple events (multiple zero-crossing event conditions) at the same time instant:
So the only thing remaining is the adjoint sensitivity problem for multiple zero-crossing event conditions. In my application especially, this is not that important, because solving FMUs backwards in time is not supported by design and causes additional overhead ...
So again, thank you very much!
PS: Are there plans to implement that last feature in the near future? If not, we could close this issue from my side, and I can open another issue to keep track of the remaining feature (in case someone searches for it or similar).
We plan to just keep going until everything is supported.
Dear @frankschae,
as promised, I tried to summarize the requirements (in the form of MWEs) that are needed to train over arbitrary FMUs using FMI.jl (or probably soon FMISensitivity.jl). Both examples are very simple NeuralODEs; we don't need to train over FMUs for the MWEs.
The requirements / MWEs are (in order of priority, most important first):

1. `ReverseDiff.gradient(...)` and `sensealg=ReverseDiffAdjoint()` (this will allow fast training on any FMU solution)
2. `ReverseDiff.gradient(...)` and `sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP())` (this will allow even faster training on solutions of FMUs that support "checkpointing", namely `fmiXGetFMUState` and `fmiXSetFMUState`)

Both MWEs run into the problem that the determined gradient contains NaNs, which would lead to NaNs in the parameters and, later, NaNs during ANN inference.
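For illustration, the shape of the first setup looks roughly like this (a toy scalar ODE and loss standing in for the NeuralODE of the actual MWE below):

```julia
using OrdinaryDiffEq, SciMLSensitivity, ReverseDiff

# Toy ODE in place of the NeuralODE from the MWE
f(u, p, t) = p .* u
prob = ODEProblem(f, [1.0], (0.0, 1.0), [-0.5])

function loss(p)
    # Explicit solver plus ReverseDiffAdjoint, as in requirement 1
    sol = solve(remake(prob; p=p), Tsit5();
                sensealg=ReverseDiffAdjoint(), saveat=0.1)
    sum(abs2, Array(sol))
end

ReverseDiff.gradient(loss, [-0.5])
```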
Some additional info:
Please don't hesitate to involve me if there is anything I can do to help. For example, we could open a PR with tests based on the MWEs and/or examples for the documentation. If anything is unclear, I can post more information/code.
If we get this working, we have a significant improvement for training ML models that include FMUs (and, in general, hybrid ODEs).
Thank you very much & best regards, ThummeTo
PS: "Unfortunately" I am on vacation for the next three weeks :-)
--------------- MWE ---------------
----------- MWE OUTPUT -------------