firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: Issue with wrapping Firedrake interpolation to vertex-only mesh in PyTorch #3709

Open: eikehmueller opened this issue 1 month ago

eikehmueller commented 1 month ago

Describe the bug: I tried to wrap the interpolation from a function on the P1 space fs to a set of points (defined via a DG0 function space vertex_only_fs on a vertex-only mesh) inside a torch.nn.Module. For this I used:

u = Function(fs)
interpolator = interpolate(TestFunction(fs), vertex_only_fs)
self._function_to_patch = fem_operator(ReducedFunctional(assemble(action(interpolator, u)), Control(u)))

and then called it in the forward(self, x) method of a class derived from torch.nn.Module as

return torch.stack([self._function_to_patch(y) for y in torch.unbind(x)])

The code crashes with the error message shown below.
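The batching pattern used in forward() (split the batch, apply the operator per sample, stack the results) can be tried out on its own, without Firedrake. This is a minimal sketch: BatchedOperator is a hypothetical module name, and a fixed 7x3 matrix-vector product stands in for self._function_to_patch, so it only checks the tensor plumbing, not the Firedrake operator itself.

```python
import torch


class BatchedOperator(torch.nn.Module):
    """Apply a per-sample operator to each entry of a batch.

    `op` is a hypothetical stand-in for the Firedrake-wrapped operator
    (self._function_to_patch in the report): any callable that maps a
    1D tensor to a 1D tensor.
    """

    def __init__(self, op):
        super().__init__()
        self._op = op

    def forward(self, x):
        # Split the batch along dim 0, apply the operator to each sample,
        # then stack the per-sample results back into a batch.
        return torch.stack([self._op(y) for y in torch.unbind(x)])


# Stand-in operator: a fixed 7x3 matrix acting on vectors of length 3.
A = torch.arange(21, dtype=torch.float64).reshape(7, 3)
layer = BatchedOperator(lambda y: A @ y)

x = torch.randn(5, 3, dtype=torch.float64)  # batch of 5 samples
out = layer(x)
print(out.shape)  # torch.Size([5, 7])
```

With a plain matrix the result equals x @ A.T, so any shape or value mismatch in the real setup points at the wrapped operator rather than at the unbind/stack loop.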

Steps to Reproduce:

  1. Minimal working example: mwe.txt

  2. To run this, use python mwe.py

Expected behavior: I would expect the code to complete without crashing.

Error message: the final part of the error message is shown below (see error.txt for the full error message):

[...]
  File "/Users/eikehmueller/Software/firedrake/src/PyOP2/pyop2/types/data_carrier.py", line 63, in __init__
    self._numpy_data = utils.verify_reshape(data, dtype, shape, allow_none=True)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/eikehmueller/Software/firedrake/src/PyOP2/pyop2/utils.py", line 248, in verify_reshape
    raise DataValueError("Invalid data: expected %d values, got %d!" %
pyop2.exceptions.DataValueError: Invalid data: expected 25 values, got 1!

Additional Info: I discussed this issue with Nacime Bouziani on 25 Jul 2024. In the call, we actually looked at the transpose of the operation (the mapping from the vertex-only function space to the original function space), but, contrary to what I thought, the issue also arises in the operation itself, as described above.

eikehmueller commented 1 month ago

The original issue with the interpolation from fs to vertex_only_fs can be resolved by calling continue_annotation() before the with set_working_tape() as _: context manager and a corresponding pause_annotation() after it.

However, the problem persists when the action of the adjoint is considered:

continue_annotation()
with set_working_tape() as _:
    w = Cofunction(vertex_only_fs.dual())
    interpolator = interpolate(TestFunction(fs), vertex_only_fs)
    self._patch_to_function = fem_operator(
        ReducedFunctional(assemble(action(adjoint(interpolator), w)), Control(w)))
pause_annotation()

When I run this (with Firedrake updated on 5 Aug 2024), I get the following warning:

WARNING:root:Adjoint value is None, is the functional independent of the control variable?

See here for an updated full MWE which illustrates the issue.
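One way to check the adjoint direction is the dot-product test. If the forward interpolation is linear, u -> A u, then the operator built from adjoint(interpolator) should act as w -> A^T w, so <w, A u> = <A^T w, u> and its Jacobian with respect to w is A^T, which is nonzero whenever A is. The sketch below is plain PyTorch under that assumption: a random matrix A is a hypothetical stand-in for the interpolation operator, and encoder/decoder are illustrative functions, not the modules from the MWE.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for the interpolation matrix: maps 3 values of u
# (playing the role of fs) to 7 point values (playing vertex_only_fs).
A = torch.randn(7, 3, dtype=torch.float64)


def encoder(u):
    # Forward interpolation: u -> A u
    return A @ u


def decoder(w):
    # Adjoint interpolation: w -> A^T w
    return A.T @ w


u = torch.randn(3, dtype=torch.float64)
w = torch.randn(7, dtype=torch.float64)

# Dot-product (adjoint) test: <w, A u> must equal <A^T w, u>.
lhs = torch.dot(w, encoder(u))
rhs = torch.dot(decoder(w), u)
print(bool(torch.isclose(lhs, rhs)))  # True

# The decoder's Jacobian with respect to w is A^T (shape 3x7).
J_dec = torch.autograd.functional.jacobian(decoder, w)
print(torch.allclose(J_dec, A.T))  # True
```

A wrapped decoder whose output does not depend on w (consistent with the "Adjoint value is None" warning) fails both checks.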

eikehmueller commented 1 month ago

Some further observations: I added some code to an extended version of the MWE which, for a linear layer y = A.x, extracts the matrix A (applying the layer to each unit vector in turn gives the columns of A) and computes the Jacobian with torch.autograd.functional.jacobian(). The latter should be identical to A. If I just use torch.nn.Linear(in_features=3, out_features=7, bias=False), this all works fine. However, for Encoder(fs, vertex_only_fs) the matrix A extracted in this way and the Jacobian do not agree. Worse, for Decoder(fs, vertex_only_fs) both A and the Jacobian are zero.
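The consistency check described above can be sketched in plain PyTorch for the torch.nn.Linear case, where it passes. The helper extract_matrix is hypothetical (it is not part of PyTorch or Firedrake); it builds A column by column from the layer's response to unit vectors, and the result is compared with the Jacobian from torch.autograd.functional.jacobian and with the layer's weight.

```python
import torch


def extract_matrix(layer, n_in):
    """Recover the matrix of a linear map by applying it to unit vectors.

    Column j of A is layer(e_j), where e_j is the j-th unit vector.
    """
    with torch.no_grad():
        eye = torch.eye(n_in)
        columns = [layer(e) for e in torch.unbind(eye)]
    return torch.stack(columns, dim=1)


layer = torch.nn.Linear(in_features=3, out_features=7, bias=False)

A = extract_matrix(layer, 3)  # shape (7, 3)
x = torch.randn(3)
J = torch.autograd.functional.jacobian(layer, x)  # shape (7, 3)

# For a linear layer without bias, both should equal the weight matrix.
print(torch.allclose(A, J), torch.allclose(A, layer.weight))  # True True
```

For a correctly wrapped linear operator the same comparison should hold; it is this check that fails for the Encoder and Decoder modules, which may need the input reshaped to match their batched forward().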