SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

adjoint_sensitivities gives incorrect gradients when g depends on the parameters #302

Closed samuela closed 2 years ago

samuela commented 4 years ago

I've found that the gradients from adjoint_sensitivities are incorrect when the g function depends on the parameters p.

Consider using an ODEProblem to describe the state of a continuous-time, infinite-horizon linear system controlled by an LQR controller, with the running cost x'Qx + u'Ru integrated over the trajectory. Because the LQR policy is optimal, the gradient of this cost integral with respect to the policy parameters should be 0. We should at least get numerically close to this when the timespan is long enough and the solver is sufficiently accurate.
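For the specific matrices used in the script below (A = 0, B = Q = R = I), the zero-gradient claim can be checked by hand. This is a simplification, assuming the policy is restricted to one scalar gain per state dimension, u = -k x; the FastDense parameterization in the issue is not modeled here.

```latex
% Continuous-time algebraic Riccati equation with A = 0, B = Q = R = I
% (taking the positive-definite solution):
A^\top P + P A - P B R^{-1} B^\top P + Q = -P^2 + I = 0
\;\Rightarrow\; P = I, \qquad K = R^{-1} B^\top P = I, \qquad u = -K x = -x

% Scalar stand-in per dimension: \dot{x} = -k x, so x(t) = x_0 e^{-k t} and
G(k) = \int_0^\infty (1 + k^2)\, x(t)^2 \, dt = \frac{(1 + k^2)\, x_0^2}{2k},
\qquad
\frac{dG}{dk} = \frac{(k^2 - 1)\, x_0^2}{2k^2}

% At the LQR gain k = 1 the gradient vanishes; with x_0 = (1, 1) the total
% cost is x_0^\top P x_0 = 2.
```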

Here's some code to reproduce this situation:

import DifferentialEquations: Vern9
import DiffEqFlux: FastDense, initial_params, ODEProblem, solve
import DiffEqSensitivity:
    InterpolatingAdjoint, ODEAdjointProblem, adjoint_sensitivities
import LinearAlgebra: I
import ControlSystems

x_dim = 2
A, B, Q, R = zeros(x_dim, x_dim), I, I, I
dynamics = (x, u) -> A * x + B * u
cost = (x, u) -> x' * Q * x + u' * R * u
x0 = ones(x_dim)

K = ControlSystems.lqr(A, B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim)

function f!(dx, x, policy_params, t)
    u = policy(x, policy_params)
    dx .= dynamics(x, u)
end

fwd_sol = solve(
    ODEProblem(f!, x0, (0, 100.0), lqr_params),
    Vern9(),
    u0 = x0,
    p = lqr_params,
)
_, g = adjoint_sensitivities(
    fwd_sol,
    Vern9(),
    (x, p, t) -> cost(x, policy(x, p)),
    nothing,
)
@assert isapprox(g, zeros(6)', atol = 1e-3)

# Doing it this way doesn't work either
# bwd_sol = solve(
#     ODEAdjointProblem(
#         fwd_sol,
#         InterpolatingAdjoint(),
#         (x, p, t) -> cost(x, policy(x, p)),
#     ),
#     Vern9(),
#     dense = false,
#     save_everystep = false,
# )

I have a hunch that this is due to the adjoint calculation missing the g_p(t) term in the integral of the dG/dp = ... equation in the docs. Although I'm not passing in dg, the docs suggest that this should not be necessary and that the appropriate gradients will be calculated automatically.
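To see why a missing g_p term would produce exactly this symptom, here is a dependency-free Julia sketch of the continuous adjoint for a scalar stand-in of the problem (x' = -k x, g = (1 + k^2) x^2, an assumed simplification of the LQR setup). It assumes the standard form dG/dk = ∫ (λ f_k + g_k) dt with dλ/dt = -f_x λ - g_x and λ(T) = 0, and it uses a hand-rolled fixed-step RK4 instead of any DiffEq solver; all function names are hypothetical. The two contributions are accumulated separately so the effect of dropping g_k is visible.

```julia
# Scalar stand-in for the issue's LQR setup (a simplification, not the
# issue's FastDense policy): u = -k*x with A = 0 and B = Q = R = 1.
f(x, k)   = -k * x              # dynamics x' = f(x, k)
f_x(x, k) = -k                  # ∂f/∂x
f_k(x, k) = -x                  # ∂f/∂k (the "f_p" term)
g_x(x, k) = 2 * (1 + k^2) * x   # ∂g/∂x for g(x, k) = (1 + k^2) * x^2
g_k(x, k) = 2 * k * x^2         # ∂g/∂k (the "g_p" term)

# One classic fourth-order Runge-Kutta step for an autonomous field F.
# Broadcasting lets the same function handle scalars and vectors.
function rk4_step(F, z, h)
    s1 = F(z)
    s2 = F(z .+ (h / 2) .* s1)
    s3 = F(z .+ (h / 2) .* s2)
    s4 = F(z .+ h .* s3)
    return z .+ (h / 6) .* (s1 .+ 2 .* s2 .+ 2 .* s3 .+ s4)
end

# Continuous adjoint for G(k) = ∫_0^T g(x(t), k) dt, keeping the
# λ·f_k and g_k contributions to dG/dk separate.
function adjoint_gradient(k; x0 = 1.0, T = 20.0, N = 4000)
    h = T / N
    # Forward pass: only x(T) is kept; x is re-integrated backwards below.
    xT = x0
    for _ in 1:N
        xT = rk4_step(x -> f(x, k), xT, h)
    end
    # Augmented state z = (x, λ, μf, μg), right-hand side in forward time:
    #   dλ/dt = -f_x λ - g_x,  dμf/dt = -λ f_k,  dμg/dt = -g_k
    F(z) = [f(z[1], k),
            -f_x(z[1], k) * z[2] - g_x(z[1], k),
            -z[2] * f_k(z[1], k),
            -g_k(z[1], k)]
    z = [xT, 0.0, 0.0, 0.0]    # λ(T) = 0 and both accumulators start at 0
    for _ in 1:N
        z = rk4_step(F, z, -h)  # negative step: integrate from T back to 0
    end
    # Now μf(0) = ∫ λ f_k dt and μg(0) = ∫ g_k dt.
    return (lambda0 = z[2], lambda_fp = z[3], g_p = z[4], total = z[3] + z[4])
end
```

At the LQR gain, adjoint_gradient(1.0) gives lambda_fp ≈ -1 and g_p ≈ +1, so the total is ≈ 0 as optimality predicts. Reporting lambda_fp alone, which is what dropping the g_p term amounts to, yields a spurious gradient of about -1.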

For reference, https://arxiv.org/pdf/2006.09178.pdf offers a treatment of how these gradients should behave, see esp. Proposition 3.4.

samuela commented 4 years ago

My hunch is further reinforced by the fact that this version works as expected:

import DifferentialEquations: Vern9
import DiffEqFlux: FastDense, initial_params, ODEProblem, solve
import DiffEqSensitivity:
    InterpolatingAdjoint, ODEAdjointProblem, adjoint_sensitivities
import LinearAlgebra: I
import ControlSystems
import Zygote

x_dim = 2
A, B, Q, R = zeros(x_dim, x_dim), I, I, I
dynamics = (x, u) -> A * x + B * u
cost = (x, u) -> x' * Q * x + u' * R * u
x0 = ones(x_dim)

K = ControlSystems.lqr(A, B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim)

function f(z, p, t)
    x = @view z[2:end]
    u = policy(x, p)
    vcat(cost(x, u), dynamics(x, u))
end

function loss(policy_params)
    z0 = vcat(0.0, x0)
    sol = solve(
        ODEProblem(f, z0, (0, 100.0), policy_params),
        Vern9(),
        u0 = z0,
        p = policy_params,
    )
    Array(sol)[1, end]
end

g, = Zygote.gradient(loss, lqr_params)
@assert isapprox(g, zeros(6), atol = 1e-3)

Unfortunately this workaround is not applicable to my use case, as I'm really after the DEStats of the backward solve (e.g. how many function evaluations it takes).
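The workaround folds the running cost into the state as an extra component, z[1]' = cost(x, u), so the loss is just the final value of that component and differentiating through solve picks up the parameter dependence of the cost automatically. The same idea in dependency-free Julia, assuming the scalar stand-in u = -k x (not the issue's FastDense policy), with fixed-step RK4 in place of Vern9 and a central finite difference standing in for Zygote.gradient:

```julia
# Right-hand side of the augmented system z = (G, x) for u = -k*x:
#   G' = x^2 + u^2 (running cost), x' = u (dynamics with A = 0, B = 1)
function augmented_rhs(z, k)
    x = z[2]
    u = -k * x
    return [x^2 + u^2, u]
end

# Fixed-step RK4 on the augmented system; the loss is the final value of G.
function loss(k; x0 = 1.0, T = 20.0, N = 4000)
    h = T / N
    z = [0.0, x0]
    for _ in 1:N
        s1 = augmented_rhs(z, k)
        s2 = augmented_rhs(z .+ (h / 2) .* s1, k)
        s3 = augmented_rhs(z .+ (h / 2) .* s2, k)
        s4 = augmented_rhs(z .+ h .* s3, k)
        z = z .+ (h / 6) .* (s1 .+ 2 .* s2 .+ 2 .* s3 .+ s4)
    end
    return z[1]
end

# Central finite difference as a stand-in for Zygote.gradient.
fd_gradient(k; δ = 1e-4) = (loss(k + δ) - loss(k - δ)) / (2δ)
```

Here fd_gradient(1.0) is ≈ 0 at the LQR gain, matching the closed form (k^2 - 1)/(2k^2). The trade-off, as noted above, is that this route gives no handle on the statistics of a separate backward solve.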

samuela commented 4 years ago

Even providing a dg doesn't seem to have any effect. The @assert false in this script never fires, so dg is never called:

import DifferentialEquations: Vern9
import DiffEqFlux: FastDense, initial_params, ODEProblem, solve
import DiffEqSensitivity:
    InterpolatingAdjoint, ODEAdjointProblem, adjoint_sensitivities
import LinearAlgebra: I
import ControlSystems
import Zygote

x_dim = 2
A, B, Q, R = zeros(x_dim, x_dim), I, I, I
dynamics = (x, u) -> A * x + B * u
cost = (x, u) -> x' * Q * x + u' * R * u
x0 = ones(x_dim)

K = ControlSystems.lqr(A, B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim)

function f!(dx, x, policy_params, t)
    u = policy(x, policy_params)
    dx .= dynamics(x, u)
end

fwd_sol = solve(
    ODEProblem(f!, x0, (0, 100.0), lqr_params),
    Vern9(),
    u0 = x0,
    p = lqr_params,
)
g = (x, p, t) -> cost(x, policy(x, p))
_, grad = adjoint_sensitivities(
    fwd_sol,
    Vern9(),
    g,
    nothing,
    dg = (out, u, p, t) -> begin
        @assert false
        ḡ, = Zygote.gradient(g, u, p, t)
        out .= ḡ
    end,
)
@assert isapprox(grad, zeros(6)', atol = 1e-3)

samuela commented 4 years ago

I'm having a hard time following some of the code, but it looks as though some of the logic here https://github.com/SciML/DiffEqSensitivity.jl/blob/45870d2e45d74789f896d6a7e43517f7e8869323/src/local_sensitivity/derivative_wrappers.jl#L471 is not firing correctly.

ChrisRackauckas commented 4 years ago

This is a duplicate of https://github.com/SciML/DiffEqSensitivity.jl/issues/286, though it would be nice to finally clean that up. As that other issue mentions, this doesn't show up when doing autodiff of solve, because the AD implementation naturally treats g as the identity function and then differentiates the next part, so it just never comes up. We already fixed it for continuous cost functions; it's just the case of dg/dp on discrete costs.

@YingboMa would you mind taking a look and seeing if we can close this for good? Even if it's just requiring that someone give the tuple with dg/dp when it's non-zero and not setting up the autodiff, that would be good enough I think.

samuela commented 4 years ago

It's just the case of dg/dp on discrete costs.

Although in this case I'm using a continuous cost function, no?
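For reference, the two kinds of cost being contrasted can be written as follows. This is a standard notation chosen for illustration, not copied from the package docs; the issue's cost is of the second, continuous kind, with g depending on p through the policy.

```latex
% Discrete cost: g evaluated only at chosen save points t_1, ..., t_N
G_{\mathrm{disc}}(p) = \sum_{i=1}^{N} g\big(u(t_i; p),\, p,\, t_i\big)

% Continuous cost: a running cost integrated over the whole time span
G_{\mathrm{cont}}(p) = \int_{t_0}^{T} g\big(u(t; p),\, p,\, t\big)\, dt
```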

ChrisRackauckas commented 4 years ago

Oh, the latest docs address this: https://docs.sciml.ai/latest/analysis/sensitivity/#Syntax-1

samuela commented 4 years ago

Thanks for the update! Based on the latest docs, I found that this works:

_, grad = adjoint_sensitivities(
    fwd_sol,
    Vern9(),
    g,
    nothing,
    (
        (out, u, p, t) -> begin
            ū, _, _ = Zygote.gradient(g, u, p, t)
            out .= ū
        end,
        (out, u, p, t) -> begin
            _, p̄, _ = Zygote.gradient(g, u, p, t)
            out .= p̄
        end,
    ),
)

but this

_, grad = adjoint_sensitivities(
    fwd_sol,
    Vern9(),
    g,
    nothing,
    dg = (
        (out, u, p, t) -> begin
            ū, _, _ = Zygote.gradient(g, u, p, t)
            out .= ū
        end,
        (out, u, p, t) -> begin
            _, p̄, _ = Zygote.gradient(g, u, p, t)
            out .= p̄
        end,
    ),
)

does not.
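A plausible explanation, assumed here rather than confirmed in the thread, is that dg was an optional positional argument, so writing dg = (...) as a keyword was collected by a trailing kwargs... and never bound to dg. The hypothetical function below reproduces that behavior in plain Julia; its name and return values are illustrative only.

```julia
# Hypothetical stand-in for an API with an optional positional `dg`
# followed by catch-all keyword arguments.
function sensitivities(sol, g, t = nothing, dg = nothing; kwargs...)
    path = dg === nothing ? "default path" : "user-supplied dg"
    return (path, collect(keys(kwargs)))
end

grad_u(out, u, p, t) = nothing   # placeholder derivative functions
grad_p(out, u, p, t) = nothing

# Positional: the tuple binds to `dg`.
positional = sensitivities(:sol, :g, nothing, (grad_u, grad_p))

# Keyword: `dg` lands in `kwargs...`, so the positional `dg` silently
# keeps its default of `nothing`.
keyword = sensitivities(:sol, :g, nothing; dg = (grad_u, grad_p))
```

The keyword call runs without error and simply takes the default path, which matches the "never fires" behavior of the @assert false above.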

It also strikes me as a bit odd that manual gradients are required for g, while the rest of the DifferentialEquations.jl ecosystem seems to do AD on dynamics functions automatically. Overall it feels like there are a lot of footguns lying around, and potential for silent failures.

ChrisRackauckas commented 4 years ago

It also strikes me as a bit odd that manual gradients are required for g, while the rest of the DifferentialEquations.jl ecosystem seems to do AD on dynamics functions automatically.

We do need to clean this part up. It wasn't part of the original interface, so it was only recently tagged on. We need to add autodiff handling for this, I agree that would be nice.
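Until that autodiff handling exists, the two derivatives can be written by hand. For the issue's cost g = x'Qx + u'Ru with Q = R = I and a linear policy u = W x + b, the partials are ∂g/∂x = 2x + 2W'u, ∂g/∂W = 2u x' and ∂g/∂b = 2u. The sketch assumes p = vcat(vec(W), b) with column-major W, which is consistent with the issue's lqr_params = vcat(-K[:], zeros(x_dim)), and it uses the in-place (out, u, p, t) signature from the thread. The helper names are hypothetical, and both functions are checked against central finite differences.

```julia
using LinearAlgebra: dot

const x_dim = 2   # state and control dimension, as in the issue

# Unpack p = vcat(vec(W), b) into the policy's weight matrix and bias.
unpack(p) = (reshape(p[1:x_dim*x_dim], x_dim, x_dim), p[x_dim*x_dim+1:end])

function policy(x, p)
    W, b = unpack(p)
    return W * x .+ b
end

function g(x, p, t)
    u = policy(x, p)
    return dot(x, x) + dot(u, u)   # Q = R = I
end

# ∂g/∂x = 2x + 2W'u, written in place (`x` is the ODE state, `u` in the thread).
function dgdu!(out, x, p, t)
    W, b = unpack(p)
    u = W * x .+ b
    out .= 2 .* x .+ 2 .* (W' * u)
end

# ∂g/∂p = vcat(vec(2u x'), 2u), matching the layout of p.
function dgdp!(out, x, p, t)
    W, b = unpack(p)
    u = W * x .+ b
    out[1:x_dim*x_dim] .= vec(2 .* u * x')
    out[x_dim*x_dim+1:end] .= 2 .* u
end

# Central finite difference of a scalar function of a vector argument.
function fd(fun, v; δ = 1e-6)
    out = similar(v)
    for i in eachindex(v)
        e = zeros(length(v)); e[i] = δ
        out[i] = (fun(v .+ e) - fun(v .- e)) / (2δ)
    end
    return out
end
```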

Overall it feels like there are a lot of footguns lying around, and potential for silent failures.

Are there any more I should be worried about? This one should go away once we add the AD there, which is just something @YingboMa never got around to. Any other footguns should be explicit failures, I think? The only other one that I know about in the DiffEq-verse is the silent kwargs issue, where a misspelled keyword such as abtol=1e-5 (instead of abstol) can silently do nothing because of the kwarg handling, but that's surprisingly difficult to fix (I plan to tackle it with @kanav99 sometime before the end of the year though). If there are any others, please let us know.
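The silent-kwargs problem mentioned above comes from keywords being forwarded through kwargs...: a misspelling such as abtol = 1e-5 is just another key, so abstol keeps its default. One generic mitigation, sketched here with hypothetical names and not describing how DiffEq actually handles it, is to validate keyword names against a known list and throw on anything unrecognized.

```julia
# Hypothetical list of keywords a solver wrapper understands.
const KNOWN_KWARGS = (:abstol, :reltol, :saveat, :dense)

# Lenient wrapper: unknown keywords are silently ignored.
function lenient_solve(; kwargs...)
    return get(kwargs, :abstol, 1e-6)   # abstol falls back to its default
end

# Strict wrapper: reject any keyword not in KNOWN_KWARGS.
function strict_solve(; kwargs...)
    unknown = [k for k in keys(kwargs) if !(k in KNOWN_KWARGS)]
    isempty(unknown) ||
        throw(ArgumentError("unrecognized keyword(s): $(join(unknown, ", "))"))
    return get(kwargs, :abstol, 1e-6)
end
```

With these definitions, lenient_solve(abtol = 1e-5) quietly returns the default 1e-6, while strict_solve(abtol = 1e-5) throws an ArgumentError naming abtol. The cost of the strict version is that every layer forwarding keywords must agree on the full list of accepted names.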

samuela commented 4 years ago

Yeah, I'll definitely keep an eye out in the future! Happy to hear that you're receptive to these reports. In this particular case I was thinking of just the hiccups encountered in this issue:

I know software is never perfect, so I appreciate your willingness to make changes and address usability issues!

samuela commented 4 years ago

I ran a little experiment (code and results) to measure the performance impact of adjoint_sensitivities(sol, solver, g) vs adjoint_sensitivities(sol, solver, g, nothing, (dgdu, dgdp)). It turns out that including (dgdu, dgdp) introduces a ~50x slowdown. This isn't super high-priority for me, since my problems are relatively small, but I figured it would be something to note.

ChrisRackauckas commented 4 years ago

Thanks, I think @YingboMa said he was going to take a week to go through some of these issues and really hammer this piece home, so I'll leave it to him.

I know software is never perfect, so I appreciate your willingness to make changes and address usability issues!

Yes, no worries, it's always about growing the test suite. A lot of cases are covered (https://github.com/SciML/DiffEqSensitivity.jl/tree/master/test/local_sensitivity), but we have much better coverage on the AD portions (what was concrete_solve) than on direct adjoint usage, and in those areas g is always the identity. I'm sure there's still work to do there, but the Zygote.gradient usage should be quite robust by now. Direct usage, especially continuous, has just had fewer people banging away at it, so checking every combination hasn't happened naturally in the wild yet; what we have tested is what we a priori thought could be issues. The missing dgdp term you're seeing was only found within the last month (https://github.com/SciML/DiffEqSensitivity.jl/pull/285), and we've been trying to correct it ASAP, but the fix is still in progress and the automated AD part is missing.

So the interface is a bit wonky and some automation is missing, but hopefully that will be all up and running by the end of the month, in which case we should then think about whether we should do a breaking interface change to make things more clear. We just didn't plan for this 🤷

I ran a little experiment (code and results) to measure the performance impact of adjoint_sensitivities(sol, solver, g) vs adjoint_sensitivities(sol, solver, g, nothing, (dgdu, dgdp)). It turns out that including (dgdu, dgdp) introduces a ~50x slowdown. This isn't super high-priority for me, since my problems are relatively small, but I figured it would be something to note.

That is something to note, but it might be doing more calculations simply because the integrand now (properly) changes, so it takes more evaluations to converge. That will need a deep dive.

ChrisRackauckas commented 2 years ago

This is all properly documented and such in the v7 release, so I'm closing it.

https://sensitivity.sciml.ai/dev/manual/direct_adjoint_sensitivities/#SciMLSensitivity.adjoint_sensitivities

Though @frankschae I do think it would be more clear in the docs if it was dgdu_continuous and dgdp_continuous. Thoughts?

ChrisRackauckas commented 2 years ago

https://github.com/SciML/SciMLSensitivity.jl/pull/675 puts a nice cap on this, since there is now

    dgdu_discrete::DG1 = nothing,
    dgdp_discrete::DG2 = nothing,
    dgdu_continuous::DG3 = nothing,
    dgdp_continuous::DG4 = nothing,

making each term explicit: no more overloading of dg with tuples and all.