firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: Differentiating a `FormSum` discards `Cofunction` weight #3292

Open · jrmaddison opened this issue 9 months ago

jrmaddison commented 9 months ago

Describe the bug: Differentiating a `FormSum` with respect to a `Cofunction` discards the weight on that `Cofunction`.

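Steps to reproduce: The following sets `u_star` equal to the domain integration cofunction, i.e. the cofunction that maps a test function `v` to the integral of `v` over the mesh. `formsum` is a linear combination of a `Form` and a `Cofunction`, and is equal to zero because `u` is identically 1.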

```python
from firedrake import *
import ufl

mesh = UnitIntervalMesh(10)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)

u = Function(space, name="u").interpolate(Constant(1.0))
# u_star is the domain integration cofunction: v -> integral of v over the mesh
u_star = Cofunction(space.dual(), name="u_star")
assemble(inner(Constant(1.0), test) * dx, tensor=u_star)

# Since u == 1, the two terms cancel and formsum is zero
formsum = 2 * inner(u, test) * dx - 2 * u_star
print(f"{assemble(assemble(formsum)(u))=}")
```

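This displays `assemble(assemble(formsum)(u))=-1.1102230246251565e-15`, which is zero up to round-off. However, differentiating `formsum` with respect to `u_star` gives the wrong result: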

```python
# Derivative of formsum with respect to u_star, in the direction u_star
der = derivative(formsum, u_star, u_star)
der = ufl.algorithms.expand_derivatives(der)
print(f"{assemble(der(u))=}")
```

This displays `assemble(der(u))=1.0` instead of the correct value `-2.0`.
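For reference, the expected value follows from linearity. Writing $u^*$ for `u_star`, $v$ for the test function and $\hat{u}^*$ for the direction of differentiation (notation introduced here for illustration only):

```latex
\begin{aligned}
F(u^*) &= 2\int_\Omega u\, v \,\mathrm{d}x - 2\,u^*, \\
\frac{\partial F}{\partial u^*}[\hat{u}^*] &= -2\,\hat{u}^*, \\
\frac{\partial F}{\partial u^*}[u^*](u) &= -2\,u^*(u) = -2\int_0^1 1 \,\mathrm{d}x = -2.
\end{aligned}
```

The observed `1.0` equals $u^*(u)$, which is consistent with the derivative being returned as $u^*$ with its weight $-2$ dropped.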

Similarly, `print(f"{str(formsum)=}")` displays `str(formsum)='{ 2 * w₂ * (conj((v_0))) } * dx(<Mesh #1>[everywhere], {})\n + u_star'`, which omits the `-2` weight on `u_star`.

Expected behavior: The weight should be included in the derivative, so that `assemble(der(u))` gives `-2.0`.
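To make the expected behaviour concrete, here is a minimal, Firedrake-independent sketch. It models a `FormSum` as a list of (weight, component) pairs; the class `ToyFormSum`, its methods and the string-labelled components are hypothetical illustrations, not the UFL or Firedrake API. Because the sum is linear in each component, the derivative with respect to one component in direction `h` should be that component's weight times `h`, and the string form should show every weight.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class ToyFormSum:
    """Hypothetical stand-in for a FormSum: sum_i weight_i * component_i.

    Components are plain string labels here; this is NOT the UFL API.
    """
    terms: List[Tuple[float, str]]

    def derivative(self, wrt: str, direction: str) -> "ToyFormSum":
        # Linear in each component: keep the total weight on `wrt`,
        # drop every term that does not depend on it.
        weight = sum(w for w, c in self.terms if c == wrt)
        return ToyFormSum([(weight, direction)])

    def evaluate(self, values: Dict[str, float]) -> float:
        # Evaluate the weighted sum given a numeric value for each component.
        return sum(w * values[c] for w, c in self.terms)

    def __str__(self) -> str:
        # Show every weight explicitly, including those on cofunction terms.
        return " + ".join(f"{w} * {c}" for w, c in self.terms)


# Mirrors formsum = 2 * inner(u, test) * dx - 2 * u_star from the reproducer
formsum = ToyFormSum([(2.0, "inner(u, test) * dx"), (-2.0, "u_star")])
der = formsum.derivative("u_star", "u_star")
print(der)
print(formsum)
# u_star(u) = 1.0 on the unit interval, per the reproducer above
print(der.evaluate({"u_star": 1.0}))
```

Here `der` keeps the `-2.0` weight, which is the behaviour the report expects from `derivative` followed by `expand_derivatives`.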

Error message: None; no error is raised.

Environment: Ubuntu 22.04, Firedrake built today.

jrmaddison commented 9 months ago

Some related bugs, involving a `Constant` weight on a `Cofunction`:

```python
from firedrake import *
import ufl

mesh = UnitIntervalMesh(10)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
u = Function(space, name="u").interpolate(Constant(1.0))

# Scale an assembled Cofunction by a named Constant; form is a FormSum with weight c
c = Constant(1.0, name="c")
form = c * assemble(inner(u, test) * dx)

# c is not found when extracting Constants from the FormSum
# assert c in ufl.algorithms.extract_type(form, type(c))  # Unexpected fail

# Replacing c does not update the FormSum weight
form = ufl.replace(form, {c: Constant(2.0)})
# assert tuple(map(float, form.weights())) == (2.0,)  # Unexpected fail
assert tuple(map(float, form.weights())) == (1.0,)  # Unexpected pass
```
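As a sketch of the behaviour the two commented-out assertions expect, the toy model below treats weights as part of the expression: coefficient extraction visits them, and replacement substitutes in them. `ToyConstant` and `ToyWeightedSum` are hypothetical classes written for illustration; they are not Firedrake's `Constant` or UFL's `FormSum`, and `replace` here only mimics the intent of `ufl.replace`.

```python
from dataclasses import dataclass
from typing import Dict, List, Set, Tuple


@dataclass(frozen=True)
class ToyConstant:
    """Hypothetical stand-in for a named Constant (not Firedrake's class)."""
    name: str
    value: float


@dataclass
class ToyWeightedSum:
    """Hypothetical FormSum analogue whose weights are themselves constants."""
    terms: List[Tuple[ToyConstant, str]]

    def weights(self) -> Tuple[ToyConstant, ...]:
        return tuple(w for w, _ in self.terms)

    def coefficients(self) -> Set[ToyConstant]:
        # Weights are part of the expression, so a traversal collecting
        # constants should include them.
        return set(self.weights())

    def replace(self, mapping: Dict[ToyConstant, ToyConstant]) -> "ToyWeightedSum":
        # Substitute in the weights as well as (conceptually) the components.
        return ToyWeightedSum([(mapping.get(w, w), comp) for w, comp in self.terms])


c = ToyConstant("c", 1.0)
form = ToyWeightedSum([(c, "assemble(inner(u, test) * dx)")])
assert c in form.coefficients()  # analogue of the first expected assertion

form = form.replace({c: ToyConstant("c2", 2.0)})
print(tuple(w.value for w in form.weights()))
```

In this model the replaced weight is `2.0`, matching the assertion that currently fails in Firedrake.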