firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: pyadjoint.ReducedFunctional.__call__ doesn't notice mesh.coordinates changes #3589

Open APaganini opened 1 month ago

APaganini commented 1 month ago

Describe the bug

We can use Firedrake/pyadjoint to compute shape derivatives. When this is done on a mesh mesh_m created from a vector function T, calling ReducedFunctional.__call__ resets the value of T.

Steps to Reproduce

from firedrake import *
from firedrake.adjoint import *

# reference mesh
mesh_r = UnitSquareMesh(5,5)
V = VectorFunctionSpace(mesh_r, "CG", 1)
X = SpatialCoordinate(mesh_r)
T = Function(V).interpolate(X)

# create tape for shape derivatives
continue_annotation()
mesh_m = Mesh(T)
W = VectorFunctionSpace(mesh_m, "CG", 1)
T_m = Function(W)
mesh_m.coordinates.assign(mesh_m.coordinates + T_m)
J = assemble(1*dx(domain=mesh_m))
Jred = ReducedFunctional(J, Control(T_m))
stop_annotating()

# failing tests
T *= 2
print("Norm of T: ", norm(T))
print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
# the following line does not notice that mesh_m.coordinates
# has changed, and even worse, it resets T and mesh_m (I don't
# know in which order) to their original values
print("(pyadjiont) Expanded area: ", Jred.__call__(T_m))
print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
print("Norm of T: ", norm(T))

Expected behavior

Calling ReducedFunctional.__call__ should not change T, and it'd be nice if ReducedFunctional.__call__ noticed changes in mesh_m.coordinates.

Error message

The code above outputs:
Norm of T:  1.632993161855452
Expanded area:  4.000000000000003
(pyadjiont) Expanded area:  1.0000000000000007
Expanded area:  1.0000000000000007
Norm of T:  0.816496580927726

The bug shows in the last two lines: after the call, the area of mesh_m and the norm of T are back to their original values. It'd also be nice if "(pyadjiont) Expanded area:" were 4 instead of 1.
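As a quick sanity check on the numbers above (a sketch that is not part of the original report): T interpolates the identity map on the unit square, so its L2 norm should be sqrt(2/3), and doubling the coordinates should scale the area by 2² = 4. The variable names below are illustrative only.

```python
import math

# L2 norm of the identity map (x, y) on the unit square:
# the integral of x^2 + y^2 over [0, 1]^2 is 1/3 + 1/3 = 2/3.
norm_T = math.sqrt(2.0 / 3.0)

# After T *= 2 the norm doubles, and the area of the unit square
# scales by the square of the factor.
norm_T_doubled = 2.0 * norm_T
area_doubled = 1.0 * 2**2

print(norm_T, norm_T_doubled, area_doubled)
```

These values agree with the first two lines of the reported output up to rounding, which suggests the values before the call are the intended ones.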

Environment:

(firedrake) $ firedrake-status
/Users/admp1/Documents/FIREDRAKE/firedrake/bin/firedrake-status:4: DeprecationWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html
  __import__('pkg_resources').require('firedrake==0.13.0+6118.g149f8fda6')
Firedrake Configuration:
    package_manager: True
    minimal_petsc: False
    mpicc: None
    mpicxx: None
    mpif90: None
    mpiexec: None
    disable_ssh: True
    honour_petsc_dir: False
    with_parmetis: False
    slepc: False
    packages: []
    honour_pythonpath: False
    opencascade: False
    tinyasm: False
    torch: False
    petsc_int_type: int32
    cache_dir: /Users/admp1/Documents/FIREDRAKE/firedrake/.cache
    complex: False
    remove_build_files: False
    with_blas: download
    netgen: False
Additions:
    None
Environment:
    PYTHONPATH: None
    PETSC_ARCH: None
    PETSC_DIR: None
Status of components:
---------------------------------------------------------------------------
|Package             |Branch                        |Revision  |Modified  |
---------------------------------------------------------------------------
|FInAT               |master                        |f352325   |False     |
|PyOP2               |master                        |7bef38fa  |False     |
|fiat                |master                        |d0bea63   |False     |
|firedrake           |master                        |149f8fda6 |False     |
|h5py                |firedrake                     |6b512e5e  |False     |
|libsupermesh        |master                        |84becef   |False     |
|loopy               |main                          |967461ba  |False     |
|petsc               |firedrake                     |272f13d92c7|False     |
|pyadjoint           |master                        |908b636   |False     |
|pytest-mpi          |main                          |f2566a1   |False     |
|tsfc                |master                        |021589c   |False     |
|ufl                 |master                        |fbd288e6  |False     |
---------------------------------------------------------------------------

colinjcotter commented 1 month ago

I have already solved this. What you need to do is a bit hacky:

1) Add T to the Controls.
2) Use derivative_components to say that you want to zero out the derivatives wrt T.

See https://github.com/dolfin-adjoint/pyadjoint/pull/99

APaganini commented 1 month ago

I tried @colinjcotter's solution, but it doesn't work. The following code gives the same output as above:

from firedrake import *
from firedrake.adjoint import *

# reference mesh
mesh_r = UnitSquareMesh(5,5)
V = VectorFunctionSpace(mesh_r, "CG", 1)
X = SpatialCoordinate(mesh_r)
T = Function(V).interpolate(X)

# create tape for shape derivatives
continue_annotation()
mesh_m = Mesh(T)
W = VectorFunctionSpace(mesh_m, "CG", 1)
T_m = Function(W)
mesh_m.coordinates.assign(mesh_m.coordinates + T_m)
J = assemble(1*dx(domain=mesh_m))
Jred = ReducedFunctional(J, (Control(T_m), Control(T)),
                         derivative_components=(0,))
stop_annotating()

# failing tests
T *= 2
print("Norm of T: ", norm(T))
print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
# the following line does not notice that mesh_m.coordinates
# has changed, and even worse, it resets T and mesh_m (I don't
# know in which order) to their original values
print("(pyadjiont) Expanded area: ", Jred.__call__([T_m, T]))
print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
print("Norm of T: ", norm(T))

colinjcotter commented 1 month ago

Yeah, I don't know anything about how mesh changes are annotated.

connorjward commented 1 month ago

I am not an expert in the adjoint but the tape does look a bit strange:

tape.pdf

We do annotate changes to the mesh coordinates but I think it's done in quite a fragile way. @dham is really the person to ask about this.

stephankramer commented 3 weeks ago

I think the evaluation of the reduced functional gives exactly the right result here. Jred() should replay the forward model exactly as it was recorded, which means starting with mesh_m at its original coordinates. It then replays the same steps and, because the control T_m hasn't changed, arrives at the same value for J. If you want to re-evaluate it with different mesh coordinates, you should do that via the control, which is what @colinjcotter is suggesting.

The only thing that might be slightly unexpected here is that the replay changes the mesh coordinates in place, i.e. it modifies the user's mesh coordinates as it goes through the replay. This is unlike what it does with Functions, where it uses independent "checkpoint" Functions to store the intermediate results of the replay. I assume this is because an independent "checkpoint" mesh object would complicate the transfer between user Functions, which live on the original mesh, and checkpoint Functions, which would then have to be redefined on the "checkpoint" mesh.

The reason it then also changes T is that you've aliased it with mesh_m.coordinates via mesh_m = Mesh(T). If you don't want that, use mesh_m = Mesh(T.copy(deepcopy=True)), but then of course you would have to do mesh_m.coordinates *= 2 as well.
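To make the aliasing point concrete, here is a toy model in plain Python. It is only an analogy: lists stand in for Functions and mesh coordinates, and the ToyTape class is hypothetical, not pyadjoint's actual tape. It assumes, as described above, that a replay restores the recorded starting state in place before redoing the recorded step.

```python
# Toy model only: plain lists stand in for Firedrake Functions and mesh
# coordinates. ToyTape is hypothetical and is not pyadjoint's Tape.

class ToyTape:
    def __init__(self, coords):
        self.coords = coords         # the user's object, not a copy
        self.initial = list(coords)  # values recorded at taping time

    def replay(self, displacement):
        # Restore the recorded starting state in place, then redo the step.
        self.coords[:] = self.initial
        for i, d in enumerate(displacement):
            self.coords[i] += d
        return sum(self.coords)      # stand-in for the functional J


T = [0.0, 0.5, 1.0]
aliased = T                  # analogous to mesh_m = Mesh(T): shared storage
tape = ToyTape(aliased)

for i in range(len(T)):      # analogous to T *= 2, done in place
    T[i] *= 2
before = list(T)             # snapshot of the doubled values

J = tape.replay([0.0, 0.0, 0.0])
after = list(T)              # T was reset through the alias

print(before, after, J)
```

In this toy, passing list(T) to ToyTape instead of T plays the role of Mesh(T.copy(deepcopy=True)): the replay would then leave T alone, but doubling T would no longer move the coordinates either.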

I suspect your implementation of what @colinjcotter suggested does not work because pyadjoint does not correctly handle this aliasing. So instead just specify the mesh itself as the control, i.e.:

from firedrake import *
from firedrake.adjoint import *

# reference mesh
mesh_r = UnitSquareMesh(5,5)
V = VectorFunctionSpace(mesh_r, "CG", 1)
X = SpatialCoordinate(mesh_r)
T = Function(V).interpolate(X)

# create tape for shape derivatives
continue_annotation()
mesh_m = Mesh(T)
W = VectorFunctionSpace(mesh_m, "CG", 1)
T_m = Function(W)
mesh_m.coordinates.assign(mesh_m.coordinates + T_m)
J = assemble(1*dx(domain=mesh_m))
Jred = ReducedFunctional(J, (Control(T_m), Control(mesh_m)),
                         derivative_components=(0,))
stop_annotating()

# failing tests
T *= 2
print("Norm of T: ", norm(T))

print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
print("(pyadjiont) Expanded area: ", Jred([T_m, mesh_m]))
print("Expanded area: ", assemble(1*dx(domain=mesh_m)))
print("Norm of T: ", norm(T))

T /= 2
print("(pyadjiont) Original area: ", Jred([T_m, mesh_m]))

dham commented 3 weeks ago

Stephan has this right. Shape derivatives were a serendipity: we didn't plan for them but more or less discovered that they were possible with minor modifications when Alberto and Florian asked about it. A consequence is that not all the corners of this are fully explored. We definitely don't properly account for aliasing of the coordinate field.

The side effect of changing the mesh is similarly an accident rather than a design decision, though fixing it would be difficult for exactly the reason that Stephan highlights. I guess it could be done by having a "checkpoint" mesh as suggested, and then being very careful that our own operations only ever apply to things defined on it. That would be a fair bit of legwork.
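The difference between replaying in place and replaying on a checkpoint can be sketched with a toy model. This is an illustration under simplifying assumptions, not a proposal for the actual implementation: lists stand in for mesh coordinates, and both functions are hypothetical rather than part of Firedrake or pyadjoint.

```python
# Toy model only: lists stand in for mesh coordinates. Neither function
# exists in Firedrake or pyadjoint.

def replay_in_place(user_coords, recorded, displacement):
    """In-place replay: the user's data is overwritten as a side effect."""
    user_coords[:] = recorded
    for i, d in enumerate(displacement):
        user_coords[i] += d
    return sum(user_coords)


def replay_on_checkpoint(user_coords, recorded, displacement):
    """Hypothetical alternative: replay on a private copy."""
    checkpoint = list(recorded)
    for i, d in enumerate(displacement):
        checkpoint[i] += d
    return sum(checkpoint)  # user_coords is never touched


recorded = [0.0, 0.5, 1.0]     # coordinates at taping time
shift = [0.25, 0.25, 0.25]     # stand-in for the control T_m

coords_a = [0.0, 1.0, 2.0]     # user has since modified the coordinates
J_a = replay_in_place(coords_a, recorded, shift)

coords_b = [0.0, 1.0, 2.0]
J_b = replay_on_checkpoint(coords_b, recorded, shift)

print(J_a, coords_a)
print(J_b, coords_b)
```

Both variants return the same functional value; only the first one changes the caller's coordinates. The hard part noted above, moving Functions between the user's mesh and a checkpoint mesh, has no counterpart in this toy.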