Did you try to pass some checkpoints explicitly to `ODEAdjointProblem`? The default is `checkpoints=sol.t` (https://github.com/SciML/DiffEqSensitivity.jl/blob/efb94fdc6d9c1b6e08396eaec5c94709ae2da3b8/src/local_sensitivity/backsolve_adjoint.jl#L83), so my first guess would be that `checkpoints=sol.t` is too many checkpoints to see the effect that you are expecting.
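For reference, a minimal sketch of what passing explicit checkpoints might look like, assuming the `fwd_sol`, `cost`, and `policy` from the reproduction further down this thread (the exact positional arguments of `ODEAdjointProblem` may differ by version):

```julia
import DiffEqSensitivity: BacksolveAdjoint, ODEAdjointProblem

# Cost functional, as in the reproduction below.
g = (x, p, t) -> cost(x, policy(x, p))

# Pass a thinned-out set of checkpoints instead of the default sol.t,
# so the effect of individual checkpoints becomes visible.
adj_prob = ODEAdjointProblem(
    fwd_sol,
    BacksolveAdjoint(checkpointing = true),
    g;
    checkpoints = fwd_sol.t[1:4:end],
)
```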
If the forward solution needs to save values, it's already by default going to be automatically using those `saveat` values as checkpoints (why not?). But I think we found a dead spot in our interface here. If you're saving nothing on the forward pass and want to return nothing but the end in the reverse pass, then you don't want to use `saveat`. The code below munges the solution to kick out the solution you actually asked for. That said, I can see a few safety and interface improvements we should be doing to this code:

1. `dense=false`: you don't want the forward solution to be non-dense if you're using `InterpolatingAdjoint`/`QuadratureAdjoint`, but right now that will be in the splatted `kwargs`. Should we just use more memory than the user wants? Should we change the default adjoint when `dense=false`?
2. `checkpoints = sol.t`: I think you have to do that to `kwargs_adj` here (https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/local_sensitivity/concrete_solve.jl#L114), but that doesn't actually affect what's saved in the forward pass (line 39). So this function has to take `checkpoints=nothing` explicitly, handle the saving of the forward pass, and pass the checkpoints along to the adjoint problem (a rough sketch of that shape follows below).
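A rough sketch of the shape described in item 2, purely illustrative (the helper name `_solve_with_checkpoints` and its wiring are assumptions, not the library's code):

```julia
import DiffEqFlux: solve
import DiffEqSensitivity: ODEAdjointProblem

# Hypothetical helper: take `checkpoints` explicitly, let it control what
# the forward pass saves, and forward it to the adjoint problem.
function _solve_with_checkpoints(prob, alg, sensealg, g; checkpoints = nothing, kwargs...)
    # No user-supplied checkpoints: save every forward step so that sol.t
    # can serve as the checkpoint set.
    sol = solve(prob, alg; save_everystep = checkpoints === nothing, kwargs...)
    cps = checkpoints === nothing ? sol.t : checkpoints
    solve(ODEAdjointProblem(sol, sensealg, g; checkpoints = cps), alg;
          save_everystep = false)
end
```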
> So my first guess would be that `checkpoints=sol.t` is too many checkpoints to see the effect that you are expecting.

So the issue ends up being the opposite: checkpointing is "turned on" but there are no checkpoints.
@ChrisRackauckas My interpretation of the interface was that checkpoints would automatically be taken at `fwd_sol.t`, based on https://github.com/SciML/DiffEqSensitivity.jl/blob/efb94fdc6d9c1b6e08396eaec5c94709ae2da3b8/src/local_sensitivity/backsolve_adjoint.jl#L83. So the user wouldn't need to specify `saveat`; they would only need to have `save_everystep = true` set for the forward solution (the default behavior).
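In other words, the expectation was something like this (a sketch, reusing `dynamics!`, `x0`, `T`, `policy_params`, and the cost functional `g` from the reproduction below):

```julia
# save_everystep = true is the default, so fwd_sol.t holds every step taken...
fwd_sol = solve(ODEProblem(dynamics!, x0, (0, T), policy_params), Tsit5())

# ...and since ODEAdjointProblem's checkpoints kwarg defaults to fwd_sol.t,
# every forward step would then act as a checkpoint.
adj_prob = ODEAdjointProblem(fwd_sol, BacksolveAdjoint(checkpointing = true), g)
```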
I'm not sure that I'm fully understanding what's going on here, but it sounds like the user is required to set `saveat` points in the forward solve in order to get checkpointing at those points in the backward solve. Is that correct? What's the workaround required for getting the behavior that I'm after?
> So the user wouldn't need to specify saveat, they would only need to have save_everystep = true set for the forward solution (the default behavior).

Yes, that is correct. I misread your example: it should checkpoint then.
> I'm not sure that I'm fully understanding what's going on here, but it sounds like the user is required to set saveat points in the forward solve in order to get checkpointing at those points in the backward solve. Is that correct? What's the workaround required for getting the behavior that I'm after?

No. What I pointed to was just the incorrect thing: that code path is only used when one is doing Zygote on `solve`, since that's the adjoint definition it uses. In that path, `saveat` is removed from the forward pass when using `InterpolatingAdjoint` and `QuadratureAdjoint` so that we have something that's dense, but `BacksolveAdjoint` is untouched.
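For contrast, that Zygote-on-`solve` path looks roughly like this (a sketch reusing the reproduction's setup; `sum` is just a stand-in loss):

```julia
import Zygote

# Differentiating through solve itself dispatches to the adjoint definition
# in concrete_solve.jl that was linked above.
loss(p) = sum(solve(ODEProblem(dynamics!, x0, (0, T), p), Tsit5(),
                    sensealg = BacksolveAdjoint(checkpointing = true)))
grad, = Zygote.gradient(loss, lqr_params)
```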
But you're not even going there: you're hitting the adjoint problems directly. So in your case, yes, as @frankschae said, the checkpoints should be coming from `fwd_sol.t`. They are added as callbacks in the reverse pass that bump the solution back.
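Roughly, the mechanism is a `DiscreteCallback` of the following shape (an illustrative sketch, not the library's exact code; see `backsolve_adjoint.jl` for the real implementation):

```julia
import DiffEqBase: DiscreteCallback

function checkpoint_bump_callback(checkpoints, fwd_sol, x_dim)
    cur_time = Ref(length(checkpoints))
    # Fires only if the reverse integrator lands exactly on a checkpoint time.
    condition = (u, t, integrator) -> begin
        idx = searchsortedfirst(checkpoints, t)
        idx <= length(checkpoints) && checkpoints[idx] == t
    end
    # "Bump" the reconstructed state (the last x_dim entries of the augmented
    # state) back to the stored forward value.
    affect! = integrator -> begin
        integrator.u[end-x_dim+1:end] .= fwd_sol.u[cur_time[]]
        cur_time[] -= 1
    end
    DiscreteCallback(condition, affect!; save_positions = (false, false))
end
```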
So I would instrument those callbacks, throw a `@show` in the condition and such, to figure out whether that's actually occurring as we'd expect. It shouldn't be too difficult to confirm. @frankschae, could you take a look? I can see that we're actually missing tests on this (@YingboMa?), so no matter what it might be good to turn this into a backsolve checkpointing test.
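For example, something like this inside the condition (the `@show` expressions are chosen to match the debug output that appears later in this thread; `checkpoints` is in scope inside the callback's closure):

```julia
condition = (u, t, integrator) -> begin
    # Clamp the index so the @show calls stay in bounds.
    idx = min(searchsortedfirst(checkpoints, t), length(checkpoints))
    @show (checkpoints[idx], t)
    @show checkpoints[idx] == t
    @show (idx, length(checkpoints))
    checkpoints[idx] == t
end
```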
> But you're not even going there: you're hitting the adjoint problems directly. So in your case, yes, as @frankschae said, the checkpoints should be coming from fwd_sol.t. They are added as callbacks in the reverse pass that bump the solution back.
Ok, so based on this I added `save_everystep = checkpointing` to the forward solve, to test that changing `fwd_sol.t` changes the gradient, but unfortunately I'm still getting equal gradients :/ Here's the complete code:
```julia
import DifferentialEquations: Tsit5
import DiffEqFlux: FastDense, ODEProblem, solve
import DiffEqSensitivity: BacksolveAdjoint, ODEAdjointProblem
import LinearAlgebra: I
import ControlSystems

# Changing x_dim or T doesn't make any difference :(
x_dim = 2
T = 40.0

A, B, Q, R = -I, I, I, I
cost = (x, u) -> x' * Q * x + u' * R * u
K = ControlSystems.lqr(Matrix{Float64}(A, x_dim, x_dim), B, Q, R)
lqr_params = vcat(-K[:], zeros(x_dim))
policy = FastDense(x_dim, x_dim, tanh) # also added a nonlinear activation in case this was some kind of aspect of linear ODEs

function dynamics!(dx, x, policy_params, t)
    dx .= A * x + B * policy(x, policy_params)
end

function backsolve_grad(x0, policy_params, checkpointing)
    fwd_sol = solve(
        ODEProblem(dynamics!, x0, (0, T), policy_params),
        Tsit5(),
        u0 = x0,
        p = policy_params,
        dense = false,
        save_everystep = checkpointing,
    )
    bwd_sol = solve(
        ODEAdjointProblem(
            fwd_sol,
            BacksolveAdjoint(checkpointing = checkpointing),
            (x, policy_params, t) -> cost(x, policy(x, policy_params)),
        ),
        Tsit5(),
        dense = false,
        save_everystep = false,
    )

    # We can see that the gradients are exactly equal whereas the recovered x(0)
    # shown in the last x_dim dimensions changes when checkpointing is
    # enabled/disabled.
    @show length(fwd_sol.t)
    @show bwd_sol.u[end]

    # In the backsolve adjoint, the last x_dim dimensions are for the
    # reconstructed x state.
    bwd_sol.u[end][1:end-x_dim]
end

x0 = ones(x_dim)
backsolve_results = backsolve_grad(x0, lqr_params, false)
println()
backsolve_checkpointing_results = backsolve_grad(x0, lqr_params, true)
@assert backsolve_results != backsolve_checkpointing_results
```
and output:
```
length(fwd_sol.t) = 2
bwd_sol.u[end] = [-5.109726156563335e12, -5.109726156563335e12, 2.640141719987574, 2.640141719987574, 2.640141719987574, 2.640141719987574, 1.4946737382195063, 1.4946737382195063, 5.109726156564334e12, 5.109726156564334e12]

length(fwd_sol.t) = 34
bwd_sol.u[end] = [-5.109726156563335e12, -5.109726156563335e12, 2.640141719987574, 2.640141719987574, 2.640141719987574, 2.640141719987574, 1.4946737382195063, 1.4946737382195063, 1.0, 1.0]
ERROR: LoadError: AssertionError: backsolve_results != backsolve_checkpointing_results
```
I just checked that I can reproduce it. I'll have a closer look tomorrow.
The issue is that the condition for the callback of the checkpoints evaluates to true only at the very last position, when the final time point is reached, i.e., at t = 0.0; see https://github.com/SciML/DiffEqSensitivity.jl/blob/1e8bf48f9cbd38d9aaeb36b57635e03bef7860ff/src/local_sensitivity/backsolve_adjoint.jl#L211. This is because the adaptive Tsit5() solver steps to different time points in the forward vs. the reverse evolution, so the exact equality check almost never fires. Therefore, the recovered x values are affected by the checkpointing option but the gradients aren't. (In other words, in both integrations, checkpointing vs. non-checkpointing, exactly the same steps happen, except that in the checkpointing one the backward solution is bumped to the forward solution at the initial value.)

@ChrisRackauckas Is there a pre-defined callback alternative from DiffEqCallbacks that we'd like to use here instead? Here's the `@show` output from the instrumented condition:
```
(checkpoints[idx], t) = (40.0, 39.63464773884082)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (34, 34)
(checkpoints[idx], t) = (40.0, 38.61205003482389)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (33, 34)
(checkpoints[idx], t) = (37.7741245208176, 37.21889445434077)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (33, 34)
(checkpoints[idx], t) = (37.7741245208176, 35.77199158941376)
checkpoints[idx] == t = false
..
(idx, length(checkpoints)) = (7, 34)
(checkpoints[idx], t) = (1.4577124265355914, 1.1905541623730385)
checkpoints[idx] == t = false
(idx, length(checkpoints)) = (1, 34)
(checkpoints[idx], t) = (0.0, 0.0)
checkpoints[idx] == t = true
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(34), 0.0)
```
`PresetTimeCallback`. A tstop needs to be placed in the reverse pass at each of the checkpoint values.
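A minimal sketch of that shape, assuming the `checkpoints`, `fwd_sol`, and `x_dim` from the example above (the `bump!` helper stands in for the library's actual affect function):

```julia
import DiffEqCallbacks: PresetTimeCallback

cur_time = Ref(length(checkpoints))
function bump!(integrator)
    # Reset the reconstructed state to the stored forward value.
    integrator.u[end-x_dim+1:end] .= fwd_sol.u[cur_time[]]
    cur_time[] -= 1
end

# PresetTimeCallback registers every checkpoint as a tstop, forcing the
# adaptive reverse solver to step exactly onto the checkpoint times, so the
# bump fires at every checkpoint instead of only at t = 0.
cb = PresetTimeCallback(checkpoints, bump!; save_positions = (false, false))
```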
Looks great:

```
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(33), 37.7741245208176)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(32), 35.37543358110203)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(4), 0.5027852060961491)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(3), 0.2734669715396772)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(2), 0.08814903521794035)
("bump solution", cur_time, integrator.t) = ("bump solution", Base.RefValue{Int64}(1), 0.0)
```
I'll prepare the tests and then send a PR.
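A sketch of what such a regression test could assert, reusing `backsolve_grad` from the reproduction above (hypothetical, not the PR's actual test):

```julia
import Test: @test

# With working checkpointing, the bumps stabilize the reverse pass, so the
# gradient must differ from the (unstable) non-checkpointed backsolve result.
@test backsolve_grad(x0, lqr_params, true) != backsolve_grad(x0, lqr_params, false)
```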
I was surprised to find today that the `checkpointing` argument (docs here) to the `BacksolveAdjoint` sensealg has no real effect on the gradient output. Here's a reproduction (the code above).

As the example shows, the recovered x(0) values are affected by the `checkpointing` parameter, but the gradients themselves are not. This is a system with relatively large eigenvalues (~1.41421) in the backsolve process, leading to instabilities that should make themselves apparent in the results. So either the checkpointing is always being done, and disabling it fails; or the checkpointing is not being done/not working, and so enabling it has no effect. I'm not sure which one it is, although I'm inclined to guess the latter.