SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License

Feat: adjoints through observable functions #689

Closed: DhairyaLGandhi closed this pull request 1 month ago

DhairyaLGandhi commented 1 month ago

Checklist

Additional context

Currently, AD through observables errors. This PR makes it possible to AD through the observable function via symbolic indexing, accumulating the gradients and returning them with respect to sol:

```julia
julia> gs3 = gradient(sol) do sol
    sum(sol[sys.w])
end
((u = [[0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]  …  [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]], u_analytic = nothing, errors = nothing, t = nothing, k = nothing, prob = (f = nothing, u0 = nothing, tspan = nothing, p = ([0.0, 2990.0, 0.0],), kwargs = nothing, problem_type = nothing), alg = nothing, interp = nothing, dense = nothing, tslocation = nothing, stats = nothing, alg_choice = nothing, retcode = nothing, resid = nothing, original = nothing),)
```
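To make "accumulate and return grads against sol" concrete, here is a minimal, dependency-free Julia sketch of the idea. It is not the SciMLBase implementation: `observable`, `observable_pullback`, `observed_series` and `state_cotangents` are hypothetical names, and the observable is assumed to be the linear function w = u[2] + u[3] + u[4] purely so the result is easy to check by hand (the real definition of `sys.w` is not shown in this PR). The forward pass evaluates the observable at each saved state, playing the role of `sol[sys.w]`; the reverse pass maps one cotangent per saved time point back onto the matching state vector, so the gradient has the same layout as the `u` field above.

```julia
# Sketch only, not SciMLBase's adjoint.
# Assumption: hypothetical linear observable w(u, p, t) = u[2] + u[3] + u[4].
observable(u, p, t) = u[2] + u[3] + u[4]

# Hand-written pullback of the hypothetical observable with respect to u:
# dw/du = [0, 1, 1, 1], scaled by the incoming cotangent Δ.
function observable_pullback(Δ, u)
    g = zero(u)
    g[2:4] .= Δ
    return g
end

# Forward pass: evaluate the observable at every saved state,
# analogous to what `sol[sys.w]` returns.
observed_series(us, p, ts) = [observable(u, p, t) for (u, t) in zip(us, ts)]

# Reverse pass: one cotangent per saved time point, mapped back onto the
# matching state vector, giving a gradient with the same shape as `sol.u`.
state_cotangents(Δs, us) = [observable_pullback(Δ, u) for (Δ, u) in zip(Δs, us)]

us = [rand(4) for _ in 1:10]          # stand-in for sol.u (10 saved states)
ts = range(0.0, 1.0; length = 10)     # stand-in for sol.t
ws = observed_series(us, nothing, ts)

# The loss is sum(ws), so its cotangent with respect to ws is all ones.
gs = state_cotangents(ones(length(ws)), us)
```

The pullback is written by hand here only to keep the sketch self-contained; with a nonlinear observable each per-step cotangent would depend on the state at that time point.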

This still needs handling for the case where the observable symbol is inside a collection (vector, tuple, ...), and for other AD backends such as ReverseDiff and Enzyme.

Ideally, this would be handled by removing all the getindex-related adjoints and letting AD do the heavy lifting, but the current approach is faster to implement.

codecov[bot] commented 1 month ago

Codecov Report

Attention: Patch coverage is 0%, with 33 lines in your changes missing coverage. Please review.

Project coverage is 29.16%. Comparing base (a0fab7a) to head (f817b52). Report is 28 commits behind head on master.

| Files | Patch % | Lines |
|---|---|---|
| ext/SciMLBaseZygoteExt.jl | 0.00% | 33 Missing ⚠️ |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #689      +/-   ##
==========================================
- Coverage   31.79%   29.16%   -2.64%
==========================================
  Files          55       55
  Lines        4535     4574      +39
==========================================
- Hits         1442     1334     -108
- Misses       3093     3240     +147
```

DhairyaLGandhi commented 1 month ago

Needs https://github.com/JuliaDiff/ChainRules.jl/pull/793

ChrisRackauckas commented 1 month ago

Add your unit tests as a new downstream testset.

ChrisRackauckas commented 1 month ago

The test at https://github.com/SciML/SciMLSensitivity.jl/blob/32f5ae7529a1957661b153f0ca9eff7e4caf0c5a/test/reversediff_output_types.jl#L14 would hit this code path.

DhairyaLGandhi commented 1 month ago

Note that with the dg/ss branch of SciMLSensitivity.jl (and https://github.com/SciML/SciMLStructures.jl/pull/18), the test at https://github.com/SciML/SciMLSensitivity.jl/blob/32f5ae7529a1957661b153f0ca9eff7e4caf0c5a/test/reversediff_output_types.jl#L14 gives:

```julia
julia> gs = gradient(u0 -> loss(u0), u0)
([-0.7779831009550049, 0.40028226620020263],)
```

DhairyaLGandhi commented 1 month ago

I've added a DAE example to the tests but disabled it until SciMLSensitivity is updated as well. The DC motor example currently fails to initialize. If there's a different test case, I can hook that in too.

DhairyaLGandhi commented 1 month ago

@ChrisRackauckas The SciMLSensitivity tests pass with https://github.com/SciML/SciMLBase.jl/pull/689/commits/d061ce47f10d99f78850e08b820727c6228b3c8b (the latest commit), but the Core (Downstream) tests get cancelled before anything runs. Is that because the Core (Python) tests fail for unrelated reasons?

gdalle commented 1 month ago

So what happens here is:

DhairyaLGandhi commented 1 month ago

Both CI/Python and CI/Downgrade seem to be failing on master as well.

gdalle commented 1 month ago

The problem I mentioned has not been fixed. It's not a problem with ADTypes per se; it's a problem with environment stacking.

DhairyaLGandhi commented 1 month ago

Is there anything left to be done in this PR?