SciML / SymbolicIndexingInterface.jl

A general interface for symbolic indexing of SciML objects used in conjunction with Domain-Specific Languages
https://docs.sciml.ai/SymbolicIndexingInterface/stable/
MIT License

Make `getu` AD friendly #69

Closed DhairyaLGandhi closed 4 months ago

DhairyaLGandhi commented 4 months ago

Is your feature request related to a problem? Please describe.

Currently `getu` returns an anonymous function that computes the symbol -> index mapping once and caches the result.
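As a rough illustration of that caching pattern (a sketch only, not the package's actual implementation; `make_getter` and `lookup` are hypothetical names), the index is resolved once and then captured by a closure:

```julia
# Hypothetical stand-in for `getu`: resolve the symbol's index once, then
# return an anonymous function that reuses the captured index on every call.
function make_getter(lookup::Dict{Symbol,Int}, sym::Symbol)
    idx = lookup[sym]      # the symbol -> index lookup happens only here
    return u -> u[idx]     # closure over the cached index
end

lookup = Dict(:x => 1, :y => 2)   # hypothetical symbol table
gety = make_getter(lookup, :y)
gety([10.0, 20.0])                # reads element 2 without a new lookup
```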

During AD, we are likely to do something like the following:

uf = getu(sol, sys.x)
gradient(sol, prob) do sol, prob
  uf(prob)
end

Here the gradient with respect to `sol` is lost for the next step of our analysis, because the anonymous function returned by `getu` does not carry that tracking.

Describe the solution you’d like

A potential solution is for `getu` to return a callable struct that is called the same way as the current anonymous function, so that gradients can be plumbed back to this struct.
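A minimal sketch of what such a callable struct might look like, assuming a plain state vector and a hypothetical symbol table (`CachedGetter`, `make_struct_getter`, and `table` are illustrative names, not part of the package):

```julia
# Hypothetical callable struct that stores the cached index as a field
# instead of hiding it inside an anonymous function.
struct CachedGetter{I}
    idx::I   # index resolved once, at construction time
end

# Making the struct callable keeps the call site identical to a closure.
(g::CachedGetter)(u::AbstractVector) = u[g.idx]

# Hypothetical constructor: look the symbol up once and store the result.
make_struct_getter(table::Dict{Symbol,Int}, sym::Symbol) = CachedGetter(table[sym])

table = Dict(:x => 1, :y => 2)    # hypothetical symbol table
getx = make_struct_getter(table, :x)
getx([10.0, 20.0])                # same call syntax as the anonymous function
```

Because the cached data lives in a named type with explicit fields, an AD system has a concrete object to attach gradient rules to, which is the motivation stated in this request.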

Additional context

We need reverse-mode AD here, and the solution also needs to match the batched interface.

AnasAbdelR commented 4 months ago

@ChrisRackauckas can you assign @AayushSabharwal and me to this issue instead?

ChrisRackauckas commented 4 months ago

Oh, I see. I thought this was about the AD of it. If it's just the struct change, @AayushSabharwal should do that for all anonymous functions in this package.

AayushSabharwal commented 4 months ago

Example for BatchedInterface AD:

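# Imports assumed for this snippet (not shown in the original comment)
using ModelingToolkit, SymbolicIndexingInterface, Zygote, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables x(t) = 1.0 y(t) = 1.0 z(t) = 1.0 w(t) = 1.0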
@named sys1 = ODESystem([D(x) ~ x + y, D(y) ~ y * z, D(z) ~ z * t * x], t)
sys1 = complete(sys1)
prob1 = ODEProblem(sys1, [], (0.0, 10.0))
@named sys2 = ODESystem([D(x) ~ x + w, D(y) ~ w * t, D(w) ~ x + y + w], t)
sys2 = complete(sys2)
prob2 = ODEProblem(sys2, [], (0.0, 10.0))

bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
getter = getu(bi)

p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
    sum(getter(prob1, prob2))
end

@test p1grad.u0 ≈ ones(3)
testp2grad = zeros(3)
testp2grad[variable_index(prob2, w)] = 1.0
@test p2grad.u0 ≈ testp2grad

DhairyaLGandhi commented 4 months ago

`getp` is due for a major rewrite, so we will need to work out which code paths AD will hit after the changes. As of right now, @AayushSabharwal is aware of the issues. I think we can close this issue if there aren't any other concerns.