SciML / SciMLBase.jl

The Base interface of the SciML ecosystem
https://docs.sciml.ai/SciMLBase/stable
MIT License

fix: Correct gradients for vector of symbols while indexing #678

Closed DhairyaLGandhi closed 2 months ago

DhairyaLGandhi commented 2 months ago

Additional context

Before this PR, indexing a solution with a vector of symbols fails the check for symbolic variables at https://github.com/SciML/SciMLBase.jl/blob/9d87ca04a15f6f94a2f3362f91a95a66950003ae/ext/SciMLBaseZygoteExt.jl#L111 and returns an all-zero gradient:

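# Imports assumed for this snippet (ModelingToolkit v9-style `t` and `D`)
using ModelingToolkit, OrdinaryDiffEq, Zygote
using ModelingToolkit: t_nounits as t, D_nounits as D

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)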

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x * (ρ - z) - y,
    D(z) ~ x * y - β * z,
    w ~ x + y + z]

@mtkbuild sys = ODESystem(eqs, t)

u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p = [σ => 28.0,
    ρ => 10.0,
    β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())
julia> gs2 = Zygote.gradient(sol) do sol
    sum(sum.(sol[[sys.x, sys.y]]))
end
(VectorOfArray{Float64, 2, Vector{Vector{Float64}}}([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]  …  [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),)

This PR:

julia> gs2 = Zygote.gradient(sol) do sol
    sum(sum.(sol[[sys.x, sys.y]]))
end
(VectorOfArray{Float64, 2, Vector{Vector{Float64}}}([[0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]  …  [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]),)
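To see why ones in two slots of every state vector is the expected answer, here is a minimal package-free sketch of the same computation on a plain vector of vectors. The helper names `select_states` and `select_states_pullback` are hypothetical, and the layout (four states, with `x` and `y` at positions 2 and 4) is an assumption read off the printed output above. This is not the rule implemented in the extension.

```julia
# Toy stand-in for a solution: one state vector per saved time point.
# Assumption: 4 states, with x and y stored at positions 2 and 4.
u = [rand(4) for _ in 1:3]
idxs = [2, 4]

# Forward pass: pick the selected states at every time point.
select_states(u, idxs) = [ut[idxs] for ut in u]

# Hand-written pullback: scatter each incoming cotangent back into the
# positions it was read from; every other entry stays zero.
function select_states_pullback(u, idxs, Δ)
    grad = [zero(ut) for ut in u]
    for (g, d) in zip(grad, Δ)
        for (k, i) in enumerate(idxs)
            g[i] += d[k]
        end
    end
    return grad
end

# Same reduction as in the PR example.
loss(u) = sum(sum.(select_states(u, idxs)))

# d(loss)/d(selected entry) is 1, so the cotangent is all ones.
Δ = [ones(length(idxs)) for _ in u]
grad = select_states_pullback(u, idxs, Δ)
```

Each entry of `grad` is `[0.0, 1.0, 0.0, 1.0]`: the loss depends linearly, with coefficient one, on the selected states and not at all on the others. An all-zero result therefore means the cotangents were never scattered back.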

Requires https://github.com/SciML/RecursiveArrayTools.jl/pull/367

codecov[bot] commented 2 months ago

Codecov Report

Attention: Patch coverage is 0%, with 20 lines in your changes missing coverage. Please review.

Project coverage is 31.64%. Comparing base (1238b2b) to head (071ad2f). Report is 11 commits behind head on master.

| Files | Patch % | Lines |
|---|---|---|
| ext/SciMLBaseZygoteExt.jl | 0.00% | 18 Missing :warning: |
| ext/SciMLBaseChainRulesCoreExt.jl | 0.00% | 2 Missing :warning: |

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master     #678      +/-   ##
==========================================
- Coverage   31.72%   31.64%   -0.08%
==========================================
  Files          55       55
  Lines        4505     4519      +14
==========================================
+ Hits         1429     1430       +1
- Misses       3076     3089      +13
```


ChrisRackauckas commented 2 months ago

Add a test?

DhairyaLGandhi commented 2 months ago

Yes, I wanted to open this to gauge feedback and see breakages downstream.

ChrisRackauckas commented 2 months ago

No breakages downstream.

DhairyaLGandhi commented 2 months ago

There seem to be a lot of failures because Codecov rate limits are being hit.

ChrisRackauckas commented 2 months ago

Downstream currently has a failure from the RealInput guesses change. Can you confirm the relevant downstream tests are passing locally for you?

DhairyaLGandhi commented 2 months ago

I do see a failure in a clean environment; checking:

ERROR: LoadError: MethodError: no method matching _getindex(::ODESolution{…}, ::NotSymbolic, ::Vector{…})

Closest candidates are:
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T<:Number, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:307
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:296
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::AbstractArray{Bool}, Colon...) where {T, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:313
  ...

DhairyaLGandhi commented 2 months ago

That was a typo. In the process I also noticed that some of the rrules in SciMLBaseChainRulesCoreExt were using old SymbolicIndexingInterface (SII) syntax, so I updated those too.

The added tests pass for me locally.

AayushSabharwal commented 2 months ago

This seems to also "just work" for getu:

julia> gxy = getu(prob, [sys.x, sys.y])
julia> gradsii = Zygote.gradient(sol) do sol
           sum(sum.(gxy(sol)))
       end
julia> gradidx = Zygote.gradient(sol) do sol
           sum(sum.(sol[[sys.x, sys.y]]))
       end
julia> gradidx[1] == gradsii[1]
true

AayushSabharwal commented 2 months ago

Might be because I'm on https://github.com/SciML/SymbolicIndexingInterface.jl/pull/72

DhairyaLGandhi commented 2 months ago

"Just works" for getu makes sense with these changes. Thanks for checking as well!

AnasAbdelR commented 2 months ago

@ChrisRackauckas thoughts on merging?