fverdugo opened this issue 4 years ago
@oriolcg and @fverdugo
I would consider the following steps:
1. Create a `CheckpointedSolution` of a lazy `ODESolution`. The obvious case is to store all steps, but we could consider more advanced constructors that accept a function deciding whether a given step must be stored for future use. The `CheckpointedSolution` should iterate like a standard `ODESolution`, with the only difference that it does not recompute steps that have already been checkpointed.

```julia
checkp_sol = CheckpointedSolution(sol, strategy) # optional strategy; default: store all steps
```
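A minimal sketch of the intended iteration semantics, assuming a toy stand-in for a lazy `ODESolution`. `LazySolution`, its fields, and this particular `CheckpointedSolution` implementation are hypothetical (they are not the Gridap API); `strategy(n)` is the "store step n or not" predicate, defaulting to storing every step. Stored steps are served from the store; unstored ones are recomputed from the previous state.

```julia
# Hypothetical stand-in for a lazy ODESolution: each step is computed on demand.
struct LazySolution{F,T}
  step::F        # (u_n, t_n) -> (u_{n+1}, t_{n+1})
  u0::T
  t0::Float64
  nsteps::Int
end

# Hypothetical CheckpointedSolution: `strategy(n)` decides whether step n is stored.
struct CheckpointedSolution{S,G,T}
  sol::S
  strategy::G
  store::Dict{Int,Tuple{T,Float64}}   # n => (u_n, t_n)
end

# Default strategy: store every step.
CheckpointedSolution(sol::LazySolution{F,T}, strategy=n -> true) where {F,T} =
  CheckpointedSolution(sol, strategy, Dict{Int,Tuple{T,Float64}}())

Base.length(c::CheckpointedSolution) = c.sol.nsteps

function Base.iterate(c::CheckpointedSolution, state=(c.sol.u0, c.sol.t0, 0))
  u, t, n = state
  n >= c.sol.nsteps && return nothing
  k = n + 1
  if haskey(c.store, k)
    u_next, t_next = c.store[k]          # already checkpointed: reuse, do not recompute
  else
    u_next, t_next = c.sol.step(u, t)    # not stored: advance from the previous state
    c.strategy(k) && (c.store[k] = (u_next, t_next))
  end
  return ((u_next, t_next), (u_next, t_next, k))
end

# Usage: 5 explicit Euler steps of u' = -u with dt = 0.1.
nevals = Ref(0)
euler_step(u, t) = (nevals[] += 1; (u - 0.1 * u, t + 0.1))
sol = LazySolution(euler_step, 1.0, 0.0, 5)
checkp_sol = CheckpointedSolution(sol)   # store all steps
first_pass = collect(checkp_sol)         # computes 5 steps
second_pass = collect(checkp_sol)        # served entirely from the store
```

With a sparser strategy such as `n -> iseven(n)`, a second pass only recomputes the odd steps, which is the trade-off between memory and recomputation that the strategy argument is meant to control.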
2. Create an `Adjoint` of an `ODEOperator`. I will work on the details of this implementation. One of the inputs of its constructor is the `CheckpointedSolution`.

```julia
adj_op = Adjoint(op, checkp_sol)
```
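For reference, a sketch of why the adjoint operator needs the forward solution, assuming an explicit ODE $\dot{u} = f(u,p,t)$ and an objective $J = \int_0^T g(u,p,t)\,dt$ (an operator written in implicit residual form leads to an analogous but different adjoint). The adjoint state $\lambda$ is integrated backwards from $t = T$, and its Jacobians are evaluated at the forward state $u(t)$, which is exactly what the checkpointed solution has to supply:

```math
\dot{\lambda}(t) = -\left(\frac{\partial f}{\partial u}\big(u(t),p,t\big)\right)^{T}\lambda(t) - \left(\frac{\partial g}{\partial u}\big(u(t),p,t\big)\right)^{T}, \qquad \lambda(T) = 0.
```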
3. Make it possible to reverse the iteration of a `CheckpointedSolution`. This requires computing and storing, in one shot, all time steps between two consecutive checkpoints, and discarding all this data when we move to the next checkpoint interval (in reverse mode). We could think about a `Forward` and `Reverse` mode trait for `CheckpointedSolution` that would modify the iteration process.

```julia
sol_reversed = reverse(checkp_sol)
```
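The interval-by-interval reversal described in step 3 can be sketched as follows, assuming a plain forward step map and a `Dict` of checkpointed states that includes the initial state. `foreach_reverse` is a hypothetical helper, not part of any existing API; it walks the checkpoint intervals from last to first, recomputes each interval forward in one shot into a buffer, hands the states out in reverse order, and then drops the buffer, so peak memory is one interval.

```julia
# Hypothetical sketch: `step(u_n) -> u_{n+1}` is the forward step map and
# `checkpoints` maps n => u_n for the stored steps (must include n = 0).
# `f(n, u_n)` is called for n = nsteps, nsteps-1, ..., 0.
function foreach_reverse(f, step, checkpoints::Dict{Int,T}, nsteps::Int) where {T}
  ks = sort!(collect(keys(checkpoints)))
  @assert first(ks) == 0 "the initial state must be checkpointed"
  bounds = vcat(ks, nsteps + 1)          # interval i covers steps bounds[i]:(bounds[i+1]-1)
  for i in length(ks):-1:1               # visit checkpoint intervals last to first
    lo, hi = bounds[i], bounds[i+1]
    buffer = Vector{T}(undef, hi - lo)   # states u_lo, ..., u_{hi-1}
    buffer[1] = checkpoints[lo]
    for j in 2:length(buffer)            # recompute the whole interval in one shot
      buffer[j] = step(buffer[j-1])
    end
    for j in length(buffer):-1:1         # hand the states out in reverse order
      f(lo + j - 1, buffer[j])
    end
    # `buffer` is dropped here, so only one interval is kept in memory at a time
  end
  return nothing
end

# Usage: u_{n+1} = 0.9 u_n (explicit Euler for u' = -u, dt = 0.1), 10 steps,
# checkpoints at n = 0, 4, 8.
decay(u) = 0.9 * u
traj = [1.0]
for n in 1:10
  push!(traj, decay(traj[end]))
end
checkpoints = Dict(n => traj[n+1] for n in (0, 4, 8))
foreach_reverse((n, u) -> println("n = $n, u = $u"), decay, checkpoints, 10)
```

Here 11 states are visited but only 8 forward steps are recomputed (one per non-checkpointed state); a `Reverse` iteration trait could wrap exactly this loop.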
4. Create the adjoint `ODESolution`. Internally, it makes use of the reversed checkpointed solution.

```julia
# Iteration of the adjoint backwards in time
for (adj_uh_n, uh_n, t_n) in adj_sol
  # here, the adjoint solution at t_n (adj_uh_n)
  # but also the forward solution at t_n (uh_n)
end
```
This way, we get the forward solution iterated in the right direction, the adjoint solution iterated in the right direction, and the forward solution at the same time step as the adjoint solution (which comes for free).
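To illustrate the `(adj_uh_n, uh_n, t_n)` iteration, here is a minimal sketch assuming a scalar ODE $\dot{u} = f(u)$ advanced with explicit Euler and an objective $J(u_N)$. `AdjointSolution` and its fields are hypothetical names; the reversed forward states are passed as a plain vector, standing in for the reversed checkpointed solution. Each backward step applies the discrete adjoint of explicit Euler, $\lambda_n = (1 + \Delta t\,\partial f/\partial u(u_n))^{T}\lambda_{n+1}$, which needs $u_n$ at exactly the step being processed.

```julia
# Hypothetical sketch of an adjoint solution iterated backwards in time.
# Assumptions: u_{n+1} = u_n + dt * f(u_n) and lamN = dJ/du_N given by the caller.
struct AdjointSolution{D,T}
  dfdu::D                                  # u -> ∂f/∂u evaluated at u
  dt::Float64
  fwd_reversed::Vector{Tuple{T,Float64}}   # (u_n, t_n) for n = N, N-1, ..., 0
  lamN::T                                  # terminal condition lambda_N = dJ/du_N
end

Base.length(a::AdjointSolution) = length(a.fwd_reversed)

# Yields (lambda_n, u_n, t_n) using the discrete adjoint of explicit Euler:
# lambda_n = (1 + dt * ∂f/∂u(u_n))' * lambda_{n+1}
function Base.iterate(a::AdjointSolution, state=(1, nothing))
  i, lam_next = state
  i > length(a.fwd_reversed) && return nothing
  u, t = a.fwd_reversed[i]
  lam = lam_next === nothing ? a.lamN : lam_next + a.dt * a.dfdu(u)' * lam_next
  return ((lam, u, t), (i + 1, lam))
end

# Usage: u' = -u^2, u_0 = 1, dt = 0.1, N = 10, J = u_N (so lambda_N = 1).
f(u) = -u^2
dfdu(u) = -2u
dt, N = 0.1, 10
fwd = [(1.0, 0.0)]
for n in 1:N
  un, tn = fwd[end]
  push!(fwd, (un + dt * f(un), tn + dt))
end
adj_sol = AdjointSolution(dfdu, dt, reverse(fwd), 1.0)
for (adj_uh_n, uh_n, t_n) in adj_sol
  println("t = $t_n, u = $uh_n, adjoint = $adj_uh_n")
end
```

The last adjoint value, $\lambda_0$, is $dJ/du_0$ for the discrete scheme, so it can be checked against a finite difference of the forward solve.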
With the DiffEq wrapper, the only thing left that I think you need for the adjoints is to make use of `p`. Some of the functions won't vjp with Zygote, though, so they would need to use ReverseDiff (array of struct); these functions should, however, be compatible with `ReverseDiffVJP(true)`.
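As a reminder of where those vjps enter, a sketch assuming $\dot{u} = f(u,p,t)$, $u(0) = u_0(p)$ and an objective $J = \int_0^T g(u,p,t)\,dt$: the gradient with respect to the parameters `p` needs the products $\lambda^{T}\,\partial f/\partial u$ (during the backward solve for $\lambda$) and $\lambda^{T}\,\partial f/\partial p$, which are the vector-Jacobian products that Zygote or ReverseDiff have to compute.

```math
\frac{dJ}{dp} = \lambda(0)^{T}\frac{\partial u_0}{\partial p} + \int_0^T \left( \lambda^{T}\frac{\partial f}{\partial p} + \frac{\partial g}{\partial p} \right) dt,
\qquad
\dot{\lambda} = -\left(\frac{\partial f}{\partial u}\right)^{T}\lambda - \left(\frac{\partial g}{\partial u}\right)^{T}, \quad \lambda(T) = 0.
```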
cc @santiagobadia @oriolcg