firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: `assemble(diagonal=True)` in DG triggers unnecessary halo exchange #2864

Open pbrubeck opened 1 year ago

pbrubeck commented 1 year ago

Assembling the diagonal of a bilinear form in DG is triggering a halo exchange.

Here's the MFE

from firedrake import *

nx = 8
mesh = UnitSquareMesh(nx, nx, quadrilateral=True)
mesh = ExtrudedMesh(mesh, nx)

V = FunctionSpace(mesh, "DQ", degree=7)
a = inner(TestFunction(V), TrialFunction(V))*dx
assemble(a, diagonal=True)

Running with `mpiexec -n 16 python bug_dg_halo.py -log_view :halo.txt:ascii_flamegraph` produces a flamegraph (the image is not reproduced here) showing the halo exchange.

The assembly of a linear form correctly omits the halo exchange.

wence- commented 1 year ago

I'm more surprised the linear form doesn't.

connorjward commented 1 year ago

We should always have a local-to-global transfer after assembly; otherwise the owned elements will be wrong, since they won't include increments from adjacent processes. I reduced the number of halo exchanges using halo "freezing", but I think there must always be one.
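
To make the point about missing increments concrete, here is a toy model in plain Python/NumPy. It is not PyOP2 code: the two "ranks", the 1D mesh and the 0.5-per-node contribution are all made up for illustration. Each rank assembles only its own cells, so the owner of a shared node lacks its neighbour's increment until the halo entry is sent back and added, which is what a local-to-global exchange with INC does.

```python
import numpy as np

# Toy 1D P1 mesh (hypothetical): 4 cells, 5 nodes, cell c touches nodes c and c + 1.
cell_nodes = {0: (0, 1), 1: (1, 2), 2: (2, 3), 3: (3, 4)}


def assemble_local(cells, local_index, n_local):
    """Add 0.5 to each node of each cell (a toy lumped-mass contribution)."""
    vec = np.zeros(n_local)
    for c in cells:
        for node in cell_nodes[c]:
            vec[local_index[node]] += 0.5
    return vec


# Rank 0 owns cells 0, 1 and nodes 0, 1, 2.
r0_index = {0: 0, 1: 1, 2: 2}
r0 = assemble_local([0, 1], r0_index, 3)

# Rank 1 owns cells 2, 3 and nodes 3, 4; node 2 is a halo copy (owned by rank 0).
r1_index = {3: 0, 4: 1, 2: 2}  # owned entries first, halo entry last
r1 = assemble_local([2, 3], r1_index, 3)

# Before the exchange, rank 0's value at node 2 only has its own cell's 0.5.
print("rank 0 before local-to-global:", r0)

# Local-to-global with INC: send the halo entry to its owner and add it there.
r0[r0_index[2]] += r1[r1_index[2]]

owned = np.concatenate([r0, r1[:2]])  # owned entries of both ranks, in node order
print("assembled owned values:", owned)
```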

wence- commented 1 year ago

For a DG function space with no facet integrals (and hence no jump terms), cell integrals only contribute locally (and do not write to halo regions), so a sufficiently smart system could determine that it is safe not to perform an exchange. But PyOP2 is certainly not that system (since it doesn't know about the discretisation).
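
A minimal sketch of that check, assuming a toy 1D mesh where a rank assembles cell integrals only over the cells it owns; `writes_to_halo` is a hypothetical helper, not part of Firedrake or PyOP2. With continuous P1 the owned cell at the partition boundary touches a node owned by the neighbour, while with discontinuous P1 every cell has private nodes, so nothing lands in the halo.

```python
def writes_to_halo(iterated_cells, cell_nodes, owned_nodes):
    """Return True if assembling cell integrals over `iterated_cells` would
    increment any node this rank does not own (toy check, hypothetical helper)."""
    return any(n not in owned_nodes for c in iterated_cells for n in cell_nodes[c])


# Toy 1D mesh with 4 cells; this rank owns cells 2 and 3 (hypothetical setup).
owned_cells = [2, 3]

# Continuous P1: neighbouring cells share a node.
cg_cell_nodes = {c: (c, c + 1) for c in range(4)}
cg_owned_nodes = {3, 4}  # node 2 is owned by the neighbouring rank

# Discontinuous P1: every cell has its own private nodes.
dg_cell_nodes = {c: (2 * c, 2 * c + 1) for c in range(4)}
dg_owned_nodes = {4, 5, 6, 7}

print("CG writes to halo:", writes_to_halo(owned_cells, cg_cell_nodes, cg_owned_nodes))  # True
print("DG writes to halo:", writes_to_halo(owned_cells, dg_cell_nodes, dg_owned_nodes))  # False
```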

connorjward commented 1 year ago

Makes sense. Pablo, I wonder if you could do something like:

  1. Check in assemble if the form you are assembling does not modify halo values
  2. Freeze the tensor halo
  3. Perform the assembly. The frozen halo means that the halo exchange is a no-op.
  4. Unfreeze the halo and mark the halo as dirty.
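
The four steps could look roughly like the following toy sketch. `ToyDat`, `frozen_halo` and `assemble_into` are hypothetical stand-ins invented for illustration, not PyOP2's actual classes or methods; only the control flow matters: freeze, assemble (the exchange becomes a no-op), then unfreeze and mark the halo dirty so that any later access needing halo values knows to refresh them.

```python
from contextlib import contextmanager


class ToyDat:
    """Hypothetical stand-in for a distributed vector; not PyOP2's Dat."""

    def __init__(self):
        self.halo_frozen = False
        self.halo_valid = True
        self.exchanges = 0  # local-to-global exchanges actually performed

    def local_to_global(self):
        if self.halo_frozen:
            return  # step 3: a frozen halo turns the exchange into a no-op
        self.exchanges += 1  # (a real implementation would communicate here)
        self.halo_valid = False  # halo copies no longer match the owners

    @contextmanager
    def frozen_halo(self):
        self.halo_frozen = True  # step 2: freeze the tensor halo
        try:
            yield self
        finally:
            self.halo_frozen = False  # step 4: unfreeze the halo ...
            self.halo_valid = False   # ... and mark it as dirty


def assemble_into(dat, modifies_halo):
    """Hypothetical assembly driver following the four steps above."""
    if modifies_halo:  # step 1: only skip the exchange when it is safe
        # ... run the assembly kernel here ...
        dat.local_to_global()
        return dat
    with dat.frozen_halo():
        # ... run the assembly kernel here ...
        dat.local_to_global()  # skipped because the halo is frozen
    return dat


d = assemble_into(ToyDat(), modifies_halo=False)
print(d.exchanges, d.halo_valid)  # 0 False
```

In this sketch step 1 is reduced to a boolean argument; deciding it for a real form would need the discretisation-aware reasoning discussed above (for example, a DG space with only cell integrals).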

pbrubeck commented 1 year ago

Ah, I think I agree that the halo must be exchanged by default. A different but related bug that I found was a halo exchange on a dat with empty halos.

from firedrake import *

nx = 8
mesh = UnitSquareMesh(nx, nx, quadrilateral=True, distribution_parameters={
                      "overlap_type": (DistributedMeshOverlapType.NONE, 0)})
mesh = ExtrudedMesh(mesh, nx)

V = FunctionSpace(mesh, "DQ", degree=7)
a = inner(TestFunction(V), TrialFunction(V))*dx
out = assemble(a, diagonal=True)

s0 = out.dat.data_ro.size
s1 = out.dat.data_ro_with_halos.size
print("rank", V.comm.rank, "halo size", s1-s0)

wence- commented 1 year ago

Quoting: "I'm more surprised the linear form doesn't."

I had a quick check, and both cases (the diagonal of the bilinear form and the linear form) trigger the same exchanges. I added this patch to PyOP2:

diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py
index 3969f5b8..8da825e6 100644
--- a/pyop2/types/dat.py
+++ b/pyop2/types/dat.py
@@ -592,6 +592,7 @@ class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC):
         halo = self.dataset.halo
         if halo is None or self._halo_frozen:
             return
+        print("l2g", self, insert_mode)
         halo.local_to_global_begin(self, insert_mode)

     @mpi.collective

And run:

from firedrake import *

mesh = UnitSquareMesh(8, 8)

V = FunctionSpace(mesh, "DG", 1)

u = TrialFunction(V)
v = TestFunction(V)
f = Function(V, name="f")
for _ in range(4):
    A = assemble(u*v*dx, diagonal=True, tensor=f)

That produces:

l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1390b6cd0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1390b6cd0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1390b6cd0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1390b6cd0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1550998b0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1550998b0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1550998b0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x1550998b0 with size 192, with dim (1,)) with datatype float64 Access.INC

For comparison, if I use a Function rather than a TrialFunction (so that u*v*dx is a linear form), I get:

l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12144ea00 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12144ea00 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12144ea00 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12144ea00 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12312afa0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12312afa0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12312afa0 with size 192, with dim (1,)) with datatype float64 Access.INC
l2g OP2 Dat: f on (OP2 DataSet: None_nodes_dset on set OP2 Set: set_#x12312afa0 with size 192, with dim (1,)) with datatype float64 Access.INC

wence- commented 1 year ago

(Replying to pbrubeck's comment above about a halo exchange on a dat with empty halos, using the no-overlap mesh example.)

A mesh with no overlap still has a topological halo (there is no cell overlap, but there are still shared entities with codimension > 0). So the point SF is not empty.
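
A toy illustration of this point in plain Python. The structured quad mesh and the column-wise split between two "ranks" are assumptions made for the example, not how Firedrake partitions meshes. Even though no cell is shared, the closures of the two cell sets still share vertices and edges, and those are the entities the point SF has to connect.

```python
def closure(cells):
    """Vertices and edges in the closure of a set of unit-square cells (i, j)."""
    verts, edges = set(), set()
    for i, j in cells:
        corners = [(i, j), (i + 1, j), (i + 1, j + 1), (i, j + 1)]
        verts.update(corners)
        # Pair each corner with the next one (wrapping around) to get the 4 edges.
        for a, b in zip(corners, corners[1:] + corners[:1]):
            edges.add(frozenset((a, b)))
    return verts, edges


nx = ny = 4
rank0 = {(i, j) for i in range(nx // 2) for j in range(ny)}
rank1 = {(i, j) for i in range(nx // 2, nx) for j in range(ny)}
assert not rank0 & rank1  # no cell overlap between the two "ranks"

v0, e0 = closure(rank0)
v1, e1 = closure(rank1)
shared_vertices = v0 & v1
shared_edges = e0 & e1
print(len(shared_vertices), "shared vertices,", len(shared_edges), "shared edges")  # 5 shared vertices, 4 shared edges
```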

If you want these optimisations, you need to implement more discretisation-specific halo construction in Firedrake.