Closed DhairyaLGandhi closed 4 months ago
@ChrisRackauckas can you assign @AayushSabharwal and me to this issue instead?
oh I see, I thought this was the AD of it. If it's just the struct change, @AayushSabharwal should do that for all anonymous functions in the package here.
Example for BatchedInterface AD:
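# Assumed setup, not shown in the original comment: package imports and the
# independent variable `t` / differential `D` (here taken from ModelingToolkit).
using ModelingToolkit, SymbolicIndexingInterface, Zygote, Test
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables x(t) = 1.0 y(t) = 1.0 z(t) = 1.0 w(t) = 1.0
@named sys1 = ODESystem([D(x) ~ x + y, D(y) ~ y * z, D(z) ~ z * t * x], t)
sys1 = complete(sys1)
prob1 = ODEProblem(sys1, [], (0.0, 10.0))
@named sys2 = ODESystem([D(x) ~ x + w, D(y) ~ w * t, D(w) ~ x + y + w], t)
sys2 = complete(sys2)
prob2 = ODEProblem(sys2, [], (0.0, 10.0))
# Batch the two systems; x and y are shared, z and w are unique to one system each.
bi = BatchedInterface((sys1, [x, y, z]), (sys2, [x, y, w]))
getter = getu(bi)
# Differentiate the sum of the batched values with respect to both problems.
p1grad, p2grad = Zygote.gradient(prob1, prob2) do prob1, prob2
    sum(getter(prob1, prob2))
end
@test p1grad.u0 ≈ ones(3)
testp2grad = zeros(3)
testp2grad[variable_index(prob2, w)] = 1.0
@test p2grad.u0 ≈ testp2grad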
getp is due for a major rewrite, so we will need to work out which code paths AD will hit after the changes. As of right now, @AayushSabharwal is aware of the issues. I think we can close this issue if there aren't any other concerns.
Is your feature request related to a problem? Please describe.
Currently, getu returns an anonymous function that computes the symbol -> index mapping once and caches the result. During AD, we are likely to do something like the following:
Here we lose tracking of gradients back to sol for the next step of our analysis. The anonymous function loses this tracking.
Describe the solution you’d like
A potential solution could be for getu to return a callable struct, which can be called the same way as the current anonymous function and which plumbs the gradients back to this struct.
Additional context
We need to do reverse-mode AD here, and this needs to also match the interface for the batched interface.
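To make the callable-struct proposal concrete, here is a minimal sketch in plain Julia. The names CachedGetter, syms, and get_y are hypothetical and are not the types or signatures that getu actually returns; the sketch only shows the idea of resolving the symbol -> index lookup once, storing it in a named struct, and calling that struct like a function. A named type also gives AD rules something concrete to dispatch on, which is the assumption behind the request.

```julia
# Hypothetical stand-in for the object getu could return; not the library's actual type.
struct CachedGetter
    idx::Int  # index of the symbol, resolved once at construction time
end

# Hypothetical constructor: look up the symbol's position once and cache it.
function CachedGetter(syms::Vector{Symbol}, sym::Symbol)
    idx = findfirst(==(sym), syms)
    idx === nothing && throw(ArgumentError("unknown symbol $sym"))
    return CachedGetter(idx)
end

# Calling the struct behaves like calling the anonymous function would.
(g::CachedGetter)(u::AbstractVector) = u[g.idx]

syms = [:x, :y, :z]
get_y = CachedGetter(syms, :y)
u = [1.0, 2.0, 3.0]
get_y(u)
```

Because the cached index lives in a field rather than in a closure's captured variables, it can be inspected (get_y.idx) and the type can be targeted by a custom differentiation rule if one is needed.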