SciML / SciMLSensitivity.jl

A component of the DiffEq ecosystem for enabling sensitivity analysis for scientific machine learning (SciML). Optimize-then-discretize, discretize-then-optimize, adjoint methods, and more for ODEs, SDEs, DDEs, DAEs, etc.
https://docs.sciml.ai/SciMLSensitivity/stable/

Checkpointing flag makes no difference on BacksolveAdjoint gradients #300

Closed samuela closed 4 years ago

samuela commented 4 years ago

I was surprised to find today that the checkpointing argument (docs here) to the BacksolveAdjoint sensealg has no real effect on the gradient output.

Here's a reproduction:

import DifferentialEquations: Tsit5
import DiffEqFlux: FastDense, ODEProblem, solve
import DiffEqSensitivity: BacksolveAdjoint, ODEAdjointProblem
import LinearAlgebra: I
import ControlSystems

# Changing x_dim or T doesn't make any difference :(
x_dim = 2
T = 40.0

A, B, Q, R = -I, I, I, I
cost = (x, u) -> x' * Q * x + u' * R * u
K = ControlSystems.lqr(Matrix{Float64}(A, x_dim, x_dim), B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim)

function dynamics!(dx, x, policy_params, t)
    dx .= A * x + B * policy(x, policy_params)
end

function backsolve_grad(x0, policy_params, checkpointing)
    fwd_sol = solve(
        ODEProblem(dynamics!, x0, (0, T), policy_params),
        Tsit5(),
        u0 = x0,
        p = policy_params,
        dense = false,
    )
    bwd_sol = solve(
        ODEAdjointProblem(
            fwd_sol,
            BacksolveAdjoint(checkpointing = checkpointing),
            (x, policy_params, t) -> cost(x, policy(x, policy_params)),
        ),
        Tsit5(),
        dense = false,
        save_everystep = false,
    )

    # We can see that the gradients are exactly equal whereas the recovered x(0)
    # shown in the last x_dim dimensions changes when checkpointing is
    # enabled/disabled.
    @show fwd_sol.u[end]
    @show bwd_sol.u[end]

    # In the backsolve adjoint, the last x_dim dimensions are for the
    # reconstructed x state.
    bwd_sol.u[end][1:end-x_dim]
end

x0 = ones(x_dim)
backsolve_results = backsolve_grad(x0, lqr_params, false)
backsolve_checkpointing_results = backsolve_grad(x0, lqr_params, true)
@assert backsolve_results != backsolve_checkpointing_results

As the example shows, the recovered x(0) values are affected by the checkpointing parameter, but the gradients themselves are not. This system has relatively large eigenvalues (~1.41421) in the backsolve process, so the resulting instability should make itself apparent in the results if checkpointing actually changed anything.
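
For reference, here's a quick check of where that ~1.41421 comes from (an illustrative sketch reusing A, B, and K from the script above): under the LQR parameters the closed-loop forward dynamics are dx/dt = (A - B*K) x ≈ -√2 · x, and integrating that backwards in time flips the sign, so the backsolve grows like exp(√2 · t).

import LinearAlgebra: eigvals

Acl = Matrix{Float64}(A, x_dim, x_dim) - B * K   # closed-loop matrix A - B*K under u = -K*x
@show eigvals(Acl)                               # ≈ [-1.41421, -1.41421]; time reversal flips the sign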

So either checkpointing is always being done and disabling it fails, or checkpointing is never being done/not working, so enabling it has no effect. I'm not sure which one it is, although I'm inclined to guess the latter.

frankschae commented 4 years ago

Did you try to pass some checkpoints explicitly to ODEAdjointProblem? The default is checkpoints=sol.t (https://github.com/SciML/DiffEqSensitivity.jl/blob/efb94fdc6d9c1b6e08396eaec5c94709ae2da3b8/src/local_sensitivity/backsolve_adjoint.jl#L83). So my first guess would be that checkpoints=sol.t is too many checkpoints to see the effect you are expecting.
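
For reference, a minimal sketch of what passing a coarser checkpoint set explicitly might look like (assuming the checkpoints keyword from that linked default is accepted by ODEAdjointProblem; the subsampling is only for illustration):

coarse_checkpoints = fwd_sol.t[1:5:end]   # e.g. keep every 5th forward time point as a checkpoint
bwd_sol = solve(
    ODEAdjointProblem(
        fwd_sol,
        BacksolveAdjoint(checkpointing = true),
        (x, policy_params, t) -> cost(x, policy(x, policy_params)),
        checkpoints = coarse_checkpoints,
    ),
    Tsit5(),
    dense = false,
    save_everystep = false,
)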

ChrisRackauckas commented 4 years ago

If the forward solution needs to save values, it will by default automatically use those saveat values as checkpoints (why not?). But I think we've found a dead spot in our interface here.

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L39

If you're saving nothing on the forward pass and want to return nothing but the end in the reverse pass, then you don't want to use saveat. The linked code munges the solution to return only the output you actually asked for. That said, I can see a few safety and interface improvements we should be making to this code:

  1. If the user passes dense=false but is using InterpolatingAdjoint/QuadratureAdjoint, the forward solution still needs to be dense, but right now dense=false will be in the splatted kwargs. Should we just use more memory than the user wants? Should we change the default adjoint when dense=false?
  2. In the direct interface (calling adjoints yourself), you can do checkpoints = sol.t. I think you have to add that to kwargs_adj here (https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L114), but that doesn't actually affect what's saved in the forward pass (line 39). So this function has to have checkpoints=nothing explicitly, then handle the saving of the forward pass, and pass it along to the adjoint problem.
ChrisRackauckas commented 4 years ago

So my first guess would be that checkpoints=sol.t are too many checkpoints to see the effect that you are expecting.

So the issue ends up being the opposite: checkpointing is "turned on" but there are no checkpoints.

samuela commented 4 years ago

@ChrisRackauckas My interpretation of the interface was that checkpoints would automatically be taken at fwd_sol.t, based on https://github.com/SciML/DiffEqSensitivity.jl/blob/efb94fdc6d9c1b6e08396eaec5c94709ae2da3b8/src/local_sensitivity/backsolve_adjoint.jl#L83. So the user wouldn't need to specify saveat; they would only need to have save_everystep = true set for the forward solution (the default behavior).

I'm not sure that I'm fully understanding what's going on here, but it sounds like the user is required to set saveat points in the forward solve in order to get checkpointing at those points in the backward solve. Is that correct? What's the workaround required for getting the behavior that I'm after?

ChrisRackauckas commented 4 years ago

So the user wouldn't need to specify saveat, they would only need to have save_everystep = true set for the forward solution (the default behavior).

Yes that is correct. I misread your example: it should checkpoint then.

I'm not sure that I'm fully understanding what's going on here, but it sounds like the user is required to set saveat points in the forward solve in order to get checkpointing at those points in the backward solve. Is that correct? What's the workaround required for getting the behavior that I'm after?

No. What I pointed to was simply the wrong thing: that code is only used when doing Zygote on solve, since it's the adjoint definition Zygote uses. There, saveat is removed from the forward pass when using InterpolatingAdjoint and QuadratureAdjoint so that we have something dense, but BacksolveAdjoint is untouched.
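
For concreteness, here's a rough sketch of the two call paths being distinguished (the loss below is just the sum of the saved states, not the cost integral from the reproduction):

import Zygote

# Path (1): "Zygote on solve" — differentiation goes through the adjoint
# definition on concrete_solve, and the sensealg keyword picks the adjoint method.
prob = ODEProblem(dynamics!, x0, (0.0, T), lqr_params)
loss(p) = sum(Array(solve(prob, Tsit5(), p = p,
                          sensealg = BacksolveAdjoint(checkpointing = true))))
grad_zygote, = Zygote.gradient(loss, lqr_params)

# Path (2): constructing ODEAdjointProblem directly, as in the reproduction above.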

But you're not even going there: you're hitting the adjoint problems directly. So in your case, yes, as @frankschae said, the checkpoints should be coming from fwd_sol.t. They are added as callbacks in the reverse pass:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/backsolve_adjoint.jl#L164

that bump the solution back:

https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/backsolve_adjoint.jl#L210-L222

So I would instrument those callbacks, throw a @show in the condition and such, to figure out whether that's actually occurring as we'd expect. It shouldn't be too difficult to confirm. @frankschae could you take a look?
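
The kind of @show instrumentation meant is roughly the following (a sketch, not the library's actual callback code; checkpoints and cur_time stand in for the variables the real condition closes over):

condition = function (u, t, integrator)
    idx = cur_time[]                        # index of the next checkpoint to hit
    @show (checkpoints[idx], t)             # candidate checkpoint vs. current reverse time
    @show checkpoints[idx] == t             # does the condition ever fire?
    checkpoints[idx] == t
end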

ChrisRackauckas commented 4 years ago

I can see that we're actually missing tests on this (@YingboMa?), so no matter what it would be good to turn this into a backsolve checkpointing test.

samuela commented 4 years ago

But you're not even going there: you're hitting the adjoint problems directly. So in your case, yes as @frankschae said the checkpoints should be coming from fwd_sol.t. They are added as callbacks in the reverse pass:

Ok, so based on this I added save_everystep = checkpointing to the forward solve, to test whether changing fwd_sol.t changes the gradient, but unfortunately I'm still getting equal gradients :/

Here's the complete code:

import DifferentialEquations: Tsit5
import DiffEqFlux: FastDense, ODEProblem, solve
import DiffEqSensitivity: BacksolveAdjoint, ODEAdjointProblem
import LinearAlgebra: I
import ControlSystems

# Changing x_dim or T doesn't make any difference :(
x_dim = 2
T = 40.0

A, B, Q, R = -I, I, I, I
cost = (x, u) -> x' * Q * x + u' * R * u
K = ControlSystems.lqr(Matrix{Float64}(A, x_dim, x_dim), B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim, tanh) # also added a nonlinear activation in case this was some kind of aspect of linear ODEs

function dynamics!(dx, x, policy_params, t)
    dx .= A * x + B * policy(x, policy_params)
end

function backsolve_grad(x0, policy_params, checkpointing)
    fwd_sol = solve(
        ODEProblem(dynamics!, x0, (0, T), policy_params),
        Tsit5(),
        u0 = x0,
        p = policy_params,
        dense = false,
        save_everystep = checkpointing,
    )
    bwd_sol = solve(
        ODEAdjointProblem(
            fwd_sol,
            BacksolveAdjoint(checkpointing = checkpointing),
            (x, policy_params, t) -> cost(x, policy(x, policy_params)),
        ),
        Tsit5(),
        dense = false,
        save_everystep = false,
    )

    # We can see that the gradients are exactly equal whereas the recovered x(0)
    # shown in the last x_dim dimensions changes when checkpointing is
    # enabled/disabled.
    @show length(fwd_sol.t)
    @show bwd_sol.u[end]

    # In the backsolve adjoint, the last x_dim dimensions are for the
    # reconstructed x state.
    bwd_sol.u[end][1:end-x_dim]
end

x0 = ones(x_dim)
backsolve_results = backsolve_grad(x0, lqr_params, false)
println()
backsolve_checkpointing_results = backsolve_grad(x0, lqr_params, true)
@assert backsolve_results != backsolve_checkpointing_results

and output:

length(fwd_sol.t) = 2
bwd_sol.u[end] = [-5.109726156563335e12, -5.109726156563335e12, 2.640141719987574, 2.640141719987574, 2.640141719987574, 2.640141719987574, 1.4946737382195063, 1.4946737382195063, 5.109726156564334e12, 5.109726156564334e12]

length(fwd_sol.t) = 34
bwd_sol.u[end] = [-5.109726156563335e12, -5.109726156563335e12, 2.640141719987574, 2.640141719987574, 2.640141719987574, 2.640141719987574, 1.4946737382195063, 1.4946737382195063, 1.0, 1.0]
ERROR: LoadError: AssertionError: backsolve_results != backsolve_checkpointing_results
frankschae commented 4 years ago

I just checked that I can reproduce it. I'll have a closer look tomorrow.

frankschae commented 4 years ago

The issue is that the condition for the checkpoint callback evaluates to true only at the very last position, when the final time point is reached, i.e., at t = 0.0; see https://github.com/SciML/DiffEqSensitivity.jl/blob/1e8bf48f9cbd38d9aaeb36b57635e03bef7860ff/src/local_sensitivity/backsolve_adjoint.jl#L211. This is because the adaptive Tsit5() solver steps to different time points in the forward vs. the reverse evolution. Therefore, the recovered x values are affected by the checkpointing option, but the gradients aren't. (In other words, both integrations, checkpointing vs. non-checkpointing, take exactly the same steps, except that in the checkpointing one the backward solution is bumped to the forward solution at the initial value.)

@ChrisRackauckas Is there a pre-defined Callback alternative from DiffEqCallbacks that we'd like to use here instead?

(checkpoints[idx], t) = (40.0, 39.63464773884082)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (34, 34)
(checkpoints[idx], t) = (40.0, 38.61205003482389)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (33, 34)
(checkpoints[idx], t) = (37.7741245208176, 37.21889445434077)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (33, 34)
(checkpoints[idx], t) = (37.7741245208176, 35.77199158941376)
checkpoints[idx] == t = false
..
(idx, length(checkpoints)) = (7, 34)
(checkpoints[idx], t) = (1.4577124265355914, 1.1905541623730385)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (1, 34)
(checkpoints[idx], t) = (0.0, 0.0)
checkpoints[idx] == t = true
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(34), 0.0)
ChrisRackauckas commented 4 years ago

PresetTimeCallback. A tstop needs to be placed in the reverse pass for this at each of the checkpoint values.
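
A minimal sketch of that suggestion, using fwd_sol.t from the reproduction as the checkpoint times (bump_to_checkpoint! is a hypothetical placeholder for the existing "bump solution" affect!; PresetTimeCallback from DiffEqCallbacks registers its times as tstops, so the reverse integrator is forced to step exactly onto every checkpoint and the exact-equality condition always fires):

import DiffEqCallbacks: PresetTimeCallback

bump_to_checkpoint! = function (integrator)
    # placeholder: reset the reconstructed state to the stored forward value,
    # as the existing checkpointing affect! already does
end

checkpoint_cb = PresetTimeCallback(fwd_sol.t, bump_to_checkpoint!)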

frankschae commented 4 years ago

Looks great:

("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(33), 37.7741245208176)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(32), 35.37543358110203)

("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(4), 0.5027852060961491)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(3), 0.2734669715396772)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(2), 0.08814903521794035)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(1), 0.0)

I'll prepare the tests and then send a PR.