gridap / GridapODEs.jl

Time stepping for Gridap
MIT License

Discussion on time-dependent adjoints #28

Open fverdugo opened 4 years ago

fverdugo commented 4 years ago

sol_lazy = solve(solver, op)

# We need a forward solution that can be iterated in reverse
sol = collect(sol_lazy)          # Option 1: store all time steps
sol = CheckpointedSol(sol_lazy)  # Option 2: store only checkpoints
sol_reversed = reverse(sol)

for (uh_n, t_n) in sol_reversed
    # forward solution at t_n, visited backward in time
end

# Adjoint (from the sol that can be iterated in reverse)
adj_op = Adjoint(op, sol, j_u(sol))

# Lazy adjoint solution
adj_sol = solve(solver, adj_op)

# Iteration of the adjoint backwards in time
for (adj_uh_n, t_n) in adj_sol
    # adjoint solution at t_n, visited backward in time
end

# If needed, reverse iteration of forward and adjoint solutions
for ((adj_uh_n, t_n), (uh_n, _)) in zip(adj_sol, sol_reversed)
    # adjoint and forward solutions at the same t_n
end
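A minimal sketch of Option 1 under toy assumptions: the lazy solution is replaced by a hypothetical `ToyODESolution` iterator (backward Euler for the scalar ODE u' = -λu) rather than Gridap's ODESolution, and the adjoint is the discrete adjoint of an assumed running cost J = Σ dt·u_n²/2. It shows why the adjoint sweep needs the forward states in reverse order, and how `zip` pairs the two backward sequences without binding `t_n` twice.

```julia
# Hypothetical toy stand-in for a lazy ODESolution: backward Euler for
# u' = -λu. Each iteration performs one step and yields (u_n, t_n), n = 1..nsteps.
struct ToyODESolution
    u0::Float64
    λ::Float64
    dt::Float64
    nsteps::Int
end

Base.IteratorSize(::Type{ToyODESolution}) = Base.HasLength()
Base.length(s::ToyODESolution) = s.nsteps
Base.eltype(::Type{ToyODESolution}) = Tuple{Float64,Float64}

function Base.iterate(s::ToyODESolution, state = (s.u0, 0.0, 0))
    u, t, n = state
    n == s.nsteps && return nothing
    u_new = u / (1 + s.λ * s.dt)         # solves (u_new - u)/dt = -λ*u_new
    t_new = t + s.dt
    return (u_new, t_new), (u_new, t_new, n + 1)
end

# Discrete adjoint of the assumed cost J = Σ_{n=1}^N dt*u_n^2/2, swept over the
# reversed forward solution: μ_n = dt*u_n + a*μ_{n+1}, with a = 1/(1 + λ*dt).
function adjoint_sweep(sol_reversed, λ, dt)
    a = 1 / (1 + λ * dt)
    μ = 0.0                               # μ_{N+1} = 0
    adj = Tuple{Float64,Float64}[]        # (μ_n, t_n) for n = N, ..., 1
    for (u_n, t_n) in sol_reversed
        μ = dt * u_n + a * μ
        push!(adj, (μ, t_n))
    end
    return adj, a * μ                     # a*μ_1 = dJ/du_0
end

sol_lazy = ToyODESolution(1.0, 2.0, 0.1, 10)
sol = collect(sol_lazy)                   # Option 1: store every step
sol_reversed = reverse(sol)
adj_sol, dJdu0 = adjoint_sweep(sol_reversed, sol_lazy.λ, sol_lazy.dt)

# Reverse iteration of adjoint and forward solutions together.
for ((adj_u_n, t_adj), (u_n, t_n)) in zip(adj_sol, sol_reversed)
    @assert t_adj == t_n                  # both sequences run backward in time
end
```

Storing every step is the simplest choice; its memory cost grows linearly with the number of steps, which is what Option 2 (checkpointing) is meant to avoid.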

cc @santiagobadia @oriolcg

santiagobadia commented 4 years ago

@oriolcg and @fverdugo

I would consider the following steps:

  1. Create a CheckpointedSolution from a lazy ODESolution. The obvious case is to store all steps, but we could also consider more advanced constructors that accept a function deciding whether a given step must be stored for later use.

  2. The CheckpointedSolution should iterate like a standard ODESolution, with the only difference that it does not recompute steps that have already been checkpointed.

    checkp_sol = CheckpointedSolution(sol, strategy) # strategy is optional; by default all steps are stored
  3. Create an Adjoint of an ODEOperator. I will work on the details of this implementation. One of the inputs of its constructor is the CheckpointedSolution.

    adj_op = Adjoint(op,checkp_sol)
  4. Make it possible to reverse the iteration of a CheckpointedSolution. This requires computing and storing, in one shot, all time steps between two consecutive checkpoints, and discarding this data when we move to the next checkpoint interval (in reverse mode). We could think about Forward and Reverse mode traits for CheckpointedSolution that would modify the iteration process.

    sol_reversed = reverse(checkp_sol)
  5. Create a lazy adjoint solution that is iterated backward in time, analogous to the forward ODESolution. Internally, it makes use of the reversed checkpointed solution.

    # Iteration of the adjoint backwards in time
    for (adj_uh_n, uh_n, t_n) in adj_sol
        # here you have the adjoint solution
        # and also the forward solution at t_n
    end
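Steps 1 and 4 can be illustrated with a toy checkpointing scheme. This is a sketch under assumptions, not the proposed Gridap API: `checkpoint`, `CheckpointedToySolution`, `euler_step` and `reverse_foreach` are hypothetical names, the problem is backward Euler for u' = -λu, the `strategy` argument is a predicate on the step index, and the reversed iteration is written as a callback instead of a `Base.iterate` method to keep it short.

```julia
# Hypothetical toy time step: backward Euler for u' = -λu.
euler_step(u, λ, dt) = u / (1 + λ * dt)

struct CheckpointedToySolution
    λ::Float64
    dt::Float64
    nsteps::Int
    checkpoints::Dict{Int,Float64}        # step index n => u_n
end

# Forward sweep (hypothetical constructor): u_0 is always stored; u_n (n >= 1)
# is stored only if strategy(n) is true. The default strategy stores every step.
function checkpoint(u0, λ, dt, nsteps; strategy = n -> true)
    cps = Dict{Int,Float64}(0 => u0)
    u = float(u0)
    for n in 1:nsteps
        u = euler_step(u, λ, dt)
        strategy(n) && (cps[n] = u)
    end
    return CheckpointedToySolution(λ, dt, nsteps, cps)
end

# Reverse iteration: calls f(u_n, t_n) for n = nsteps, ..., 1. Each interval
# between consecutive checkpoints is recomputed in one shot into a buffer,
# consumed backward, then abandoned. Returns the largest buffer length used.
function reverse_foreach(f, s::CheckpointedToySolution)
    marks = sort!(collect(keys(s.checkpoints)))
    hi, peak = s.nsteps, 0
    for lo in reverse(marks)
        lo >= hi && continue                   # nothing to recompute above lo
        buf = Vector{Float64}(undef, hi - lo)  # u_{lo+1}, ..., u_hi
        u = s.checkpoints[lo]
        for k in eachindex(buf)
            u = euler_step(u, s.λ, s.dt)
            buf[k] = u
        end
        peak = max(peak, length(buf))
        for k in reverse(eachindex(buf))
            f(buf[k], (lo + k) * s.dt)
        end
        hi = lo            # buf is no longer referenced; only checkpoints persist
    end
    return peak
end

u0, λ, dt, N = 1.0, 2.0, 0.1, 20
cs = checkpoint(u0, λ, dt, N; strategy = n -> n % 5 == 0)   # keep every 5th step
got = Tuple{Float64,Float64}[]
peak = reverse_foreach((u_n, t_n) -> push!(got, (u_n, t_n)), cs)

# Store-all reference (Option 1) for comparison: u_1, ..., u_N.
ref_states = accumulate((u, _) -> euler_step(u, λ, dt), 1:N; init = u0)
```

With a checkpoint every 5 steps, this toy holds at most 5 recomputed states at a time, at the price of recomputing each interval once during the reverse sweep. It also recomputes a checkpointed state at the top of each interval, which a real implementation following step 2 would skip.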

This way, we get the forward solution iterated forward in time, the adjoint solution iterated backward in time, and the forward solution at the same time step as the adjoint solution (which comes for free, since the adjoint already uses the reversed forward solution internally).
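The lazy adjoint solution of step 5 can be sketched as a Julia iterator that consumes the reversed forward solution and yields `(adj_u_n, u_n, t_n)`. The names `ToyAdjointSolution` and `forward_reversed` are hypothetical, the forward problem is again backward Euler for u' = -λu, and the functional J = Σ dt·u_n²/2 is an assumption chosen so that the adjoint actually depends on u_n. A plain reversed vector stands in for `reverse(checkp_sol)`.

```julia
# Hypothetical lazy adjoint solution for backward Euler applied to u' = -λu,
# with assumed running cost J = Σ_{n=1}^N dt*u_n^2/2. It yields
# (adj_u_n, u_n, t_n) for n = N, ..., 1, one adjoint step per iteration.
struct ToyAdjointSolution
    fwd_reversed::Vector{Tuple{Float64,Float64}}   # (u_n, t_n), n = N, ..., 1
    λ::Float64
    dt::Float64
end

Base.IteratorSize(::Type{ToyAdjointSolution}) = Base.HasLength()
Base.length(s::ToyAdjointSolution) = length(s.fwd_reversed)
Base.eltype(::Type{ToyAdjointSolution}) = NTuple{3,Float64}

# State = (position in fwd_reversed, μ_{n+1}). Adjoint step:
# μ_n = ∂J/∂u_n + (∂u_{n+1}/∂u_n)*μ_{n+1} = dt*u_n + μ_{n+1}/(1 + λ*dt).
function Base.iterate(s::ToyAdjointSolution, state = (1, 0.0))
    i, μ_next = state
    i > length(s.fwd_reversed) && return nothing
    u_n, t_n = s.fwd_reversed[i]
    μ_n = s.dt * u_n + μ_next / (1 + s.λ * s.dt)
    return (μ_n, u_n, t_n), (i + 1, μ_n)
end

# Forward sweep stored in full and reversed (stand-in for reverse(checkp_sol)).
function forward_reversed(u0, λ, dt, N)
    sol = Tuple{Float64,Float64}[]
    u = u0
    for n in 1:N
        u = u / (1 + λ * dt)
        push!(sol, (u, n * dt))
    end
    return reverse(sol)
end

u0, λ, dt, N = 1.0, 2.0, 0.1, 10
adj_sol = ToyAdjointSolution(forward_reversed(u0, λ, dt, N), λ, dt)

ts = Float64[]
for (adj_u_n, u_n, t_n) in adj_sol
    push!(ts, t_n)    # adjoint and forward solution are both available at t_n
end
dJdu0 = last(collect(adj_sol))[1] / (1 + λ * dt)   # pull μ_1 back to u_0
```

Because the iterator carries μ_{n+1} in its state, nothing beyond the forward data is stored, which mirrors the laziness of the forward ODESolution.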

ChrisRackauckas commented 3 years ago

With the DiffEq wrapper, I think the only thing left for the adjoints is to make use of the parameters p. Some of the functions won't vjp with Zygote, though, so they would need array-of-structs ReverseDiff; these functions should be compatible with ReverseDiffVJP(true).
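To illustrate what making use of p means for an adjoint, here is a hand-rolled discrete adjoint for the parameter gradient of a toy problem. It is not the DiffEq/SciMLSensitivity machinery: the vector-Jacobian product `vjp_g` is written by hand where Zygote or ReverseDiff would normally supply it, and `g`, `cost` and `cost_and_grad_p` are hypothetical names. The problem is backward Euler for u' = -p·u with an assumed running cost, and the gradient is checked against a central finite difference.

```julia
# Hypothetical step map: backward Euler for u' = -p*u (p plays the role of DiffEq's p).
g(u, p, dt) = u / (1 + p * dt)

# Hand-written vector-Jacobian product of g: returns (μ*∂g/∂u, μ*∂g/∂p).
# In a real wrapper, Zygote or ReverseDiff would provide this.
vjp_g(u, p, dt, μ) = (μ / (1 + p * dt), -μ * u * dt / (1 + p * dt)^2)

# Assumed running cost J(p) = Σ_{n=1}^N dt*u_n^2/2.
function cost(u0, p, dt, N)
    u, J = u0, 0.0
    for _ in 1:N
        u = g(u, p, dt)
        J += dt * u^2 / 2
    end
    return J
end

# Discrete adjoint: store u_0..u_N, then sweep backward, pulling μ_n through
# each step u_{n-1} -> u_n with vjp_g and accumulating the p-contributions.
function cost_and_grad_p(u0, p, dt, N)
    us = Vector{Float64}(undef, N + 1)       # us[n + 1] = u_n
    us[1] = u0
    for n in 1:N
        us[n + 1] = g(us[n], p, dt)
    end
    J = sum(u -> dt * u^2 / 2, us[2:end])
    pulled = 0.0                             # (∂g/∂u)*μ_{n+1} from the later step
    dJdp = 0.0
    for n in N:-1:1
        μ_n = dt * us[n + 1] + pulled        # μ_n = ∂J/∂u_n + (∂g/∂u)*μ_{n+1}
        pulled, dp = vjp_g(us[n], p, dt, μ_n)
        dJdp += dp
    end
    return J, dJdp
end

u0, p, dt, N = 1.0, 2.0, 0.1, 10
J, dJdp = cost_and_grad_p(u0, p, dt, N)
h = 1e-6
fd = (cost(u0, p + h, dt, N) - cost(u0, p - h, dt, N)) / (2h)   # central difference
```

The only problem-specific pieces are the step map and its VJP, which is why an AD backend that can differentiate the residual functions is what remains to be plugged in.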