firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: firedrake.triplot is not working in parallel #3760

Closed diego-hayashi closed 2 months ago

diego-hayashi commented 2 months ago

Describe the bug
Plotting a mesh with triplot while running in parallel raises an IndexError.

Steps to Reproduce
Run the following code with mpiexec -n 2 python -u code.py:

from firedrake import *
import matplotlib.pyplot as plt

N = 3
mesh = UnitSquareMesh(N, N)

triplot(mesh)
plt.show()

Expected behavior
triplot should run without raising an error, and two plots (one per MPI process) should appear on the screen.

Error message
The traceback is:

Traceback (most recent call last):
  File "code.py", line 7, in <module>
    triplot(mesh)
  File "petsc4py/PETSc/Log.pyx", line 188, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "petsc4py/PETSc/Log.pyx", line 189, in petsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func
  File "/home/opt/firedrake/firedrake/pyplot/mpl.py", line 129, in triplot
    vertices = coords[cell_node_map[:, idx]]
               ~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 3 is out of bounds for axis 0 with size 3
application called MPI_Abort(PYOP2_COMM_WORLD, 1) - process 0

diego-hayashi commented 2 months ago

It seems to work if the halo-inclusive arrays (data_ro_with_halos and values_with_halo) are used in firedrake/pyplot/mpl.py instead of the owned-only ones:

diff --git a/firedrake/pyplot/mpl.py b/firedrake/pyplot/mpl.py
index 984c08f97..4dc90f754 100644
--- a/firedrake/pyplot/mpl.py
+++ b/firedrake/pyplot/mpl.py
@@ -119,12 +119,12 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}):
         V = VectorFunctionSpace(mesh, element.family(), 1)
         coordinates = assemble(Interpolate(coordinates, V))

-    coords = toreal(coordinates.dat.data_ro, "real")
+    coords = toreal(coordinates.dat.data_ro_with_halos, "real")
     result = []
     interior_kw = dict(interior_kw)
     # If the domain isn't a 3D volume, draw the interior.
     if tdim <= 2:
-        cell_node_map = coordinates.cell_node_map().values
+        cell_node_map = coordinates.cell_node_map().values_with_halo
         idx = (tuple(range(tdim + 1)) if not quad else (0, 1, 3, 2)) + (0,)
         vertices = coords[cell_node_map[:, idx]]

@@ -141,12 +141,12 @@ def triplot(mesh, axes=None, interior_kw={}, boundary_kw={}):
         if typ == "interior":
             facets = mesh.interior_facets
             node_map = coordinates.interior_facet_node_map()
-            node_map = node_map.values[:, :node_map.arity//2]
-            local_facet_ids = facets.local_facet_dat.data_ro[:, :1].reshape(-1)
+            node_map = node_map.values_with_halo[:, :node_map.arity//2]
+            local_facet_ids = facets.local_facet_dat.data_ro_with_halos[:, :1].reshape(-1)
         elif typ == "exterior":
             facets = mesh.exterior_facets
-            local_facet_ids = facets.local_facet_dat.data_ro
-            node_map = coordinates.exterior_facet_node_map().values
+            local_facet_ids = facets.local_facet_dat.data_ro_with_halos
+            node_map = coordinates.exterior_facet_node_map().values_with_halo
         else:
             raise ValueError("Unhandled facet type")
         mask = np.zeros(node_map.shape, dtype=bool)
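
As a rough illustration of why the halo-inclusive arrays might help, here is a NumPy-only sketch. It is not Firedrake code: the array names and values are made up, and it assumes the failure comes from a locally owned cell referencing a node owned by another process. That would be consistent with the "index 3 ... size 3" error above, but it is not confirmed here.

```python
import numpy as np

# Hypothetical data for one MPI process: 3 owned nodes plus 1 halo node
# (a node owned by a neighbouring process). All values are made up.
coords_owned = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
coords_with_halo = np.vstack([coords_owned, [[1.0, 1.0]]])

# One locally owned triangle whose third vertex is the halo node (index 3).
cell_node_map = np.array([[1, 2, 3]])

# Same index pattern triplot uses for triangles: repeat the first vertex
# so that each cell becomes a closed polygon.
tdim = 2
idx = tuple(range(tdim + 1)) + (0,)

# Indexing the owned-only coordinates fails, because the map refers to
# a node that is not in that array.
try:
    coords_owned[cell_node_map[:, idx]]
    error_message = None
except IndexError as exc:
    error_message = str(exc)
print(error_message)

# Indexing the halo-inclusive coordinates succeeds.
vertices = coords_with_halo[cell_node_map[:, idx]]
print(vertices.shape)
```

In this toy setup the owned-only lookup fails in the same way as the traceback, while the halo-inclusive lookup returns one closed polygon with four vertices in two dimensions.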

The resulting plots are shown below:

[figure: the resulting triplot plots]