firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

Adjoint derivative taking a bit long #1621

Open: anezkap opened this issue 4 years ago

anezkap commented 4 years ago

Hi, I am solving a two-equation problem in Firedrake, and I am a bit concerned about the disproportionate amount of time it takes to compute the derivative of my reduced functional Jhat.

The 1D problem I'm trying to solve is as follows:

[Screenshot, 2020-03-05: the model equations for c and q]

The equation for c and the solution method are essentially the same as in the "DG advection equation with upwinding" tutorial. k and k_2 are constants, and the control here is c_in.
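Reading the weak form and the time loop in the code below (an inference from the code, not a transcription of the screenshot), the system and the time-stepping scheme appear to be the following, with $\Delta f(g)$ denoting the increment obtained by solving $M\,\Delta f = \Delta t\,R(g)$ for the mass matrix $M$ and the right-hand side $R$:

```latex
\begin{aligned}
\frac{\partial c}{\partial t} + \frac{\partial (u\,c)}{\partial x} &= -k\,q\,c,
  \qquad \frac{\partial q}{\partial t} = -k_2\,q\,c,
  \qquad x \in (0, 1),\; u = 1,\\
c(t, 0) &= c_{\mathrm{in}}, \qquad c(0, x) = 0, \qquad q(0, x) = 1,
  \qquad J(c_{\mathrm{in}}) = c(T, 1),\; T = 2,\\[4pt]
f^{(1)} &= f^{n} + \Delta f\big(f^{n}\big),\\
f^{(2)} &= \tfrac{3}{4} f^{n} + \tfrac{1}{4}\Big(f^{(1)} + \Delta f\big(f^{(1)}\big)\Big),\\
f^{n+1} &= \tfrac{1}{3} f^{n} + \tfrac{2}{3}\Big(f^{(2)} + \Delta f\big(f^{(2)}\big)\Big).
\end{aligned}
```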

Solving for c and q takes only a couple of seconds, but computing Jhat.derivative() takes around 6 minutes.

Does this look normal, or is there a problem in my code? Is there a way to make this faster?
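Before changing solvers, it may help to see where the six minutes actually go. Below is a minimal sketch using only Python's standard `cProfile` and `pstats` modules; the helper name `profile_call` is hypothetical. In the script one would call it as `d, report = profile_call(Jhat.derivative)` and print the report, where the cumulative-time ordering shows which pyadjoint/Firedrake functions dominate.

```python
import cProfile
import io
import pstats


def profile_call(func, *args, top=10, **kwargs):
    """Run func(*args, **kwargs) under cProfile.

    Returns (result, report), where report is the text of the `top`
    entries sorted by cumulative time (time spent in a function plus
    everything it calls)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args, **kwargs)
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top)
    return result, stream.getvalue()
```

Sorting by cumulative time rather than per-function time makes it easy to see whether the cost sits in the linear solves themselves or in the setup work around them.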

Thank you for your help!

from firedrake import *
from firedrake_adjoint import *

# Set up the mesh
mesh = UnitIntervalMesh(40)

# Set up the function spaces
Vec = VectorFunctionSpace(mesh, "CG", 1)
V_c = FunctionSpace(mesh, "DG", 1)
V_q = FunctionSpace(mesh, "DG", 0)
W = V_c*V_q

# Get the spatial coordinate for x and set constant velocity with static boundary conditions
x, = SpatialCoordinate(mesh)

velocity = as_vector((1, ))
u = Function(Vec).interpolate(velocity)
c_in = Constant(1.0)

bcs = [DirichletBC(W.sub(0), c_in, 1)]

# Set the initial condition
f = Function(W)
with stop_annotating():
    c, q = f.split()
    q.assign(1.0)

# Set time T, step dt
T = 2
dt = T/600
dtc = Constant(dt)

# Set the left hand side of our equation
dc_trial, dq_trial = TrialFunctions(W)
phi, psi = TestFunctions(W)
a = phi*dc_trial*dx + psi*dq_trial*dx

# We define ``n`` to be the built-in ``FacetNormal`` object; a unit normal vector
# that can be used in integrals over exterior and interior facets.  We next define
# ``un`` to be an object which is equal to :math:`\vec{u}\cdot\vec{n}` if this is
# positive, and zero if this is negative. This will be useful in the upwind terms.
n = FacetNormal(mesh)
un = 0.5*(dot(u, n) + abs(dot(u, n)))

k = 0.8
k2 = 0.1

# Right-hand side
L1 = dtc*(c*div(phi*u)*dx
          - conditional(dot(u, n) < 0, phi*dot(u, n)*c_in, 0.0)*ds
          - conditional(dot(u, n) > 0, phi*dot(u, n)*c, 0.0)*ds
          - (phi('+') - phi('-'))*(un('+')*c('+') - un('-')*c('-'))*dS
          - k*phi*q*c*dx
          - k2*psi*q*c*dx)

# Runge-Kutta: the stage right-hand sides are L1 evaluated at the stage values f1 and f2
f1 = Function(W)
f2 = Function(W)
L2 = replace(L1, {c: split(f1)[0], q: split(f1)[1]})
L3 = replace(L1, {c: split(f2)[0], q: split(f2)[1]})

# We now declare a variable to hold the temporary increments at each stage.
df = Function(W)

# We make use of the ``LinearVariationalProblem`` and
# ``LinearVariationalSolver`` objects for each of our Runge-Kutta stages.
params = {'ksp_type': 'preonly', 'pc_type': 'bjacobi', 'sub_pc_type': 'ilu', 'mat_type': 'aij'}
prob1 = LinearVariationalProblem(a, L1, df, bcs=bcs)
solv1 = LinearVariationalSolver(prob1, solver_parameters=params)
prob2 = LinearVariationalProblem(a, L2, df, bcs=bcs)
solv2 = LinearVariationalSolver(prob2, solver_parameters=params)
prob3 = LinearVariationalProblem(a, L3, df, bcs=bcs)
solv3 = LinearVariationalSolver(prob3, solver_parameters=params)

# Run the time loop with three Runge-Kutta stages, and write the results
# into the results list
t = 0.0
step = 0

with stop_annotating():
    c_, q_ = f.split()
    results = [[Function(c_)], [Function(q_)]]

while t < T - 0.5*dt:
    solv1.solve()
    f1.assign(f + df)

    solv2.solve()
    f2.assign(0.75*f + 0.25*(f1 + df))

    solv3.solve()
    f.assign((1.0/3.0)*f + (2.0/3.0)*(f2 + df))

    with stop_annotating():
        c_, q_ = f.split()
        results[0].append(Function(c_))
        results[1].append(Function(q_))

    step += 1
    t += dt

# Set up control and reduced functional Jhat
c, q = split(f)
J = assemble(c*ds(2))
m = Control(c_in)
Jhat = ReducedFunctional(J, m)

d = Jhat.derivative()
print(d.dat.data)
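As a cross-check on the number printed by `Jhat.derivative()`, here is a pure-Python sketch of a simplified analogue of the model. It assumes piecewise-constant values for c (first-order upwind finite volumes, i.e. a DG0 stand-in for the DG1 discretization used above), reuses the three SSPRK3 stage formulas from the time loop, and estimates dJ/dc_in with a central difference. The function names and the DG0 simplification are assumptions for illustration, so the value will only approximate the Firedrake result. Within pyadjoint itself, the `taylor_test` function serves a similar purpose for checking the adjoint gradient.

```python
def rhs(c, q, c_in, u, dx, k, k2):
    """Semi-discrete right-hand side: upwind advection of c plus reactions."""
    n = len(c)
    dc = [0.0] * n
    dq = [0.0] * n
    for i in range(n):
        upstream = c_in if i == 0 else c[i - 1]  # u > 0: inflow from the left
        dc[i] = -u * (c[i] - upstream) / dx - k * q[i] * c[i]
        dq[i] = -k2 * q[i] * c[i]
    return dc, dq


def solve_forward(c_in, n_cells=40, T=2.0, n_steps=600, u=1.0, k=0.8, k2=0.1):
    """SSPRK3 with the same stage weights as the issue's loop.

    Returns J = c in the last (outflow) cell at time T."""
    dx = 1.0 / n_cells
    dt = T / n_steps
    c = [0.0] * n_cells  # c starts at zero
    q = [1.0] * n_cells  # q starts at one

    def increment(c_s, q_s):
        dc, dq = rhs(c_s, q_s, c_in, u, dx, k, k2)
        return [dt * v for v in dc], [dt * v for v in dq]

    for _ in range(n_steps):
        dc, dq = increment(c, q)  # stage 1: f1 = f + df
        c1 = [a + b for a, b in zip(c, dc)]
        q1 = [a + b for a, b in zip(q, dq)]
        dc, dq = increment(c1, q1)  # stage 2: f2 = 3/4 f + 1/4 (f1 + df)
        c2 = [0.75 * a + 0.25 * (b + d) for a, b, d in zip(c, c1, dc)]
        q2 = [0.75 * a + 0.25 * (b + d) for a, b, d in zip(q, q1, dq)]
        dc, dq = increment(c2, q2)  # stage 3: f = 1/3 f + 2/3 (f2 + df)
        c = [a / 3.0 + 2.0 / 3.0 * (b + d) for a, b, d in zip(c, c2, dc)]
        q = [a / 3.0 + 2.0 / 3.0 * (b + d) for a, b, d in zip(q, q2, dq)]
    return c[-1]


def central_difference(func, x, h):
    """Second-order finite-difference estimate of func'(x)."""
    return (func(x + h) - func(x - h)) / (2.0 * h)
```

Usage: `central_difference(solve_forward, 1.0, 1e-3)` estimates dJ/dc_in at c_in = 1. Because the reaction term couples q and c, J is nonlinear in c_in, so the estimate depends on the point at which it is taken.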
salazardetroya commented 4 years ago

It is likely that the adjoint solver is using an iterative method and taking many iterations, because the solver parameters from the forward solve are not passed to the adjoint solve. Please see this issue in the pyadjoint repo and use the patch suggested there.
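A plausible reason the forward parameters matter so much: the left-hand side `a` is a DG mass matrix on `W = V_c*V_q`, which has no coupling between cells and is therefore block diagonal. For such a matrix, ILU(0) creates no fill-in outside the existing blocks, so it is an exact factorization, and `'ksp_type': 'preonly'` with block Jacobi/ILU solves each stage exactly in a single application. The sketch below uses a hand-written ILU(0) (not PETSc's implementation; all helper names are hypothetical) to check this on a small matrix shaped like per-cell DG1/DG0 blocks, and contrasts it with a matrix where fill-in is dropped and ILU(0) is only approximate.

```python
def ilu0(a):
    """ILU(0): incomplete LU that only updates entries already nonzero in a.

    Returns one matrix holding L strictly below the diagonal (L has an
    implicit unit diagonal) and U on and above the diagonal."""
    n = len(a)
    lu = [row[:] for row in a]
    keep = [[a[i][j] != 0.0 for j in range(n)] for i in range(n)]
    for i in range(1, n):
        for k in range(i):
            if not keep[i][k]:
                continue
            lu[i][k] /= lu[k][k]
            for j in range(k + 1, n):
                if keep[i][j]:  # fill-in outside the pattern is dropped
                    lu[i][j] -= lu[i][k] * lu[k][j]
    return lu


def lu_solve(lu, b):
    """Solve (L U) x = b by forward, then back, substitution."""
    n = len(b)
    y = list(b)
    for i in range(n):
        for j in range(i):
            y[i] -= lu[i][j] * y[j]
    x = list(y)
    for i in reversed(range(n)):
        for j in range(i + 1, n):
            x[i] -= lu[i][j] * x[j]
        x[i] /= lu[i][i]
    return x


def max_residual(a, x, b):
    """Largest |(A x - b)_i|."""
    return max(abs(sum(aij * xj for aij, xj in zip(row, x)) - bi)
               for row, bi in zip(a, b))


def dg_mass_matrix(n_cells):
    """Block-diagonal mass matrix, one 3x3 block per cell: a 2x2 linear
    Lagrange (DG1) block for c and a 1x1 DG0 block for q."""
    h = 1.0 / n_cells
    block = [[2 * h / 6, h / 6, 0.0],
             [h / 6, 2 * h / 6, 0.0],
             [0.0, 0.0, h]]
    n = 3 * n_cells
    a = [[0.0] * n for _ in range(n)]
    for cell in range(n_cells):
        for r in range(3):
            for s in range(3):
                a[3 * cell + r][3 * cell + s] = block[r][s]
    return a


mass = dg_mass_matrix(5)
b = [1.0] * 15
x = lu_solve(ilu0(mass), b)  # residual at round-off level: ILU(0) is exact here

arrow = [[4.0, 1.0, 1.0], [1.0, 4.0, 0.0], [1.0, 0.0, 4.0]]
x2 = lu_solve(ilu0(arrow), [1.0, 1.0, 1.0])  # residual about 0.05: fill dropped
```

If the adjoint solves instead fall back to a generic iterative setup, each of them may need several iterations where the forward solve needed none, which would multiply the cost across all the time steps.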

anezkap commented 4 years ago

Thank you for your help. I thought this (many iterations because the forward solver parameters are not passed to the adjoint solve) might be the problem as well, but unfortunately the suggested patch did not help. That's okay, though; I do not really need my code to be super fast at the moment.

florianwechsung commented 4 years ago

Just to check, did you verify that the options are still not passed or did you just observe that the code is still slow? In the latter case, it may be that there is a separate issue. I suspect the problem is that pyadjoint solves all adjoint equations via the solve(...) interface, which has a fair bit of overhead.
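A toy model of the overhead described here (not pyadjoint's actual code): the forward loop above runs 600 steps of 3 stages, i.e. 1800 linear solves, and the adjoint performs roughly one linear solve per forward solve. If each adjoint solve goes through a one-shot `solve(...)` call that redoes its setup, that cost is paid 1800 times, whereas the forward loop pays it once per reused `LinearVariationalSolver`. The class and function names below are hypothetical.

```python
class CountingSetup:
    """Hypothetical stand-in for per-solve setup work (building solver
    objects, assembling the operator, ...). Here it just inverts a 2x2
    matrix and counts how often it is called."""

    def __init__(self):
        self.calls = 0

    def __call__(self, matrix):
        self.calls += 1
        (a, b), (c, d) = matrix
        det = a * d - b * c
        return ((d / det, -b / det), (-c / det, a / det))


def apply_inverse(inv, rhs):
    """Multiply the stored 2x2 inverse by a right-hand side."""
    return (inv[0][0] * rhs[0] + inv[0][1] * rhs[1],
            inv[1][0] * rhs[0] + inv[1][1] * rhs[1])


def one_shot_solves(setup, matrix, rhs_list):
    """Pay the setup cost on every call, like a fresh solve(...) each time."""
    return [apply_inverse(setup(matrix), rhs) for rhs in rhs_list]


def reused_solves(setup, matrix, rhs_list):
    """Pay the setup cost once, like one solver object reused in a loop."""
    inv = setup(matrix)
    return [apply_inverse(inv, rhs) for rhs in rhs_list]


n_steps, n_stages = 600, 3  # T / dt = 600 steps, 3 SSPRK3 stages each
rhs_list = [(1.0, float(i)) for i in range(n_steps * n_stages)]
matrix = ((2.0, 1.0), (1.0, 2.0))

fresh, cached = CountingSetup(), CountingSetup()
x_fresh = one_shot_solves(fresh, matrix, rhs_list)
x_cached = reused_solves(cached, matrix, rhs_list)
```

Both variants return identical solutions; only the number of setup calls differs (1800 versus 1), which is why per-solve overhead, rather than the linear algebra itself, can dominate a long adjoint run.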

anezkap commented 4 years ago

Hi again, and sorry I left this open for so long. It turns out that faster code would actually be useful to me, even though it is not crucial. I only observed that my code is still slow; I did not verify whether the options are passed. I'm not sure how to do that. Could you give me some guidance on how to verify it, please? Thanks a lot!
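One possible check (a suggestion, not something confirmed in this thread): add PETSc's standard `ksp_converged_reason` flag to the forward parameters from the script. Every solve that receives these parameters then prints a one-line convergence report. If the reports appear during the time loop but not during `Jhat.derivative()`, the adjoint solves are not getting the forward parameters. In Firedrake's `solver_parameters`, a value of `None` marks a flag option; `ksp_view` could be added the same way for a full solver description, at the cost of very verbose output over 1800 solves.

```python
# Forward solver parameters from the script, plus a PETSc monitoring flag.
params = {
    "ksp_type": "preonly",
    "pc_type": "bjacobi",
    "sub_pc_type": "ilu",
    "mat_type": "aij",
    "ksp_converged_reason": None,  # print why/when each linear solve stopped
}
```

Shortening T (for example to a few time steps) while checking keeps the output manageable.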

wence- commented 4 years ago

We recently (today) merged some changes (#1804) that make this kind of use faster. Can you update and check?