firedrakeproject / firedrake

Firedrake is an automated system for the portable solution of partial differential equations using the finite element method (FEM)
https://firedrakeproject.org

BUG: Periodic utility meshes cannot create mesh hierarchies in parallel #2984

Open connorjward opened 1 year ago

connorjward commented 1 year ago

Running the following code in parallel

from firedrake import *

mesh = PeriodicIntervalMesh(3, 1)
mh = MeshHierarchy(mesh, 3)

raises the exception

(firedrake) [connor@elitebook840 firedrake]$ mpiexec -np 2 python mytest.py
Traceback (most recent call last):
  File "/home/connor/.local/opt/firedrake/src/firedrake/mytest.py", line 6, in <module>
    mh = MeshHierarchy(mesh, 3)
  File "/home/connor/.local/opt/firedrake/src/firedrake/firedrake/mg/mesh.py", line 109, in MeshHierarchy
    raise RuntimeError("Cannot refine parallel overlapped meshes "
RuntimeError: Cannot refine parallel overlapped meshes (make sure the MeshHierarchy is built immediately after the Mesh)
application called MPI_Abort(PYOP2_COMM_WORLD, 1) - process 0
Traceback (most recent call last):
  File "/home/connor/.local/opt/firedrake/src/firedrake/mytest.py", line 6, in <module>
    mh = MeshHierarchy(mesh, 3)
  File "/home/connor/.local/opt/firedrake/src/firedrake/firedrake/mg/mesh.py", line 109, in MeshHierarchy
    raise RuntimeError("Cannot refine parallel overlapped meshes "
RuntimeError: Cannot refine parallel overlapped meshes (make sure the MeshHierarchy is built immediately after the Mesh)
application called MPI_Abort(PYOP2_COMM_WORLD, 1) - process 1

The same error is observed for other periodic meshes like PeriodicRectangleMesh.

This error happens because we create these periodic meshes by first creating a manifold mesh and then transforming its coordinates. This means that every periodic constructor in utility_meshes.py contains code like the following:

    coord_fs = VectorFunctionSpace(
        m, FiniteElement("DG", interval, 1, variant="equispaced"), dim=1
    )
    old_coordinates = m.coordinates
    new_coordinates = Function(
        coord_fs, name=mesh._generate_default_mesh_coordinates_name(name)
    )

Constructing this VectorFunctionSpace calls init() on the mesh, which adds the parallel overlap (halo). Once that has happened, the topology can no longer be used to construct a mesh hierarchy.
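The reason the periodic coordinates need a discontinuous (DG) space at all can be shown without Firedrake. The NumPy sketch below unrolls an N-cell mesh of the unit circle into [0, L] cell by cell, mirroring the logic of the `periodise` kernel posted later in this thread (atan2, shift negative values by one period, and send a vertex at angle 0 to L when it belongs to the cell below the x-axis). The function name `unroll_circle_cells` is hypothetical. The vertex at angle 0 ends up with x = 0 in the first cell and x = L in the last, which a continuous coordinate field cannot represent.

```python
import numpy as np


def unroll_circle_cells(n_cells, length=1.0):
    """Per-cell (DG-style) x-coordinates obtained by unrolling an
    n_cells-segment mesh of the unit circle onto [0, length]."""
    # Vertex j sits at angle 2*pi*j/n_cells on the unit circle.
    angles = 2 * np.pi * np.arange(n_cells) / n_cells
    verts = np.stack([np.cos(angles), np.sin(angles)], axis=1)
    # Cell c joins vertex c to vertex (c + 1) % n_cells.
    idx = np.arange(n_cells)
    cells = np.stack([idx, (idx + 1) % n_cells], axis=1)

    dg_x = np.empty((n_cells, 2))
    for c, cell in enumerate(cells):
        xy = verts[cell]
        y_sum = xy[:, 1].sum()  # negative for the cell just below the x-axis
        x = np.arctan2(xy[:, 1], xy[:, 0]) / (2 * np.pi)
        x = np.where(x < 0, x + 1, x)  # map angles from (-pi, pi] to [0, 1)
        # In the cell below the axis, the vertex at angle 0 belongs at x = 1.
        x = np.where((x == 0) & (y_sum < 0), 1.0, x)
        dg_x[c] = x * length
    return dg_x


dg = unroll_circle_cells(3, length=2.0)
print(dg)  # the shared vertex appears as 0.0 in the first cell and 2.0 in the last
```

Each cell gets its own copy of its vertex coordinates, which is exactly what a DG coordinate field stores.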

ksagiyam commented 1 year ago

It seems we should make those Firedrake DG coordinates in the callback.

pefarrell commented 1 year ago

The way I got around this previously was by making a base mesh of the manifold, making the mesh hierarchy of that, periodising each one, then snapping the vertices of the periodised meshes so that they form a nested sequence.
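The snapping step in this workaround relies on the unrolled refined meshes being close to, but not exactly on, a nested equispaced grid. Below is a Firedrake-free 1D sketch (a circle instead of a cylinder). The helpers `refine_circle`, `unroll` and `snap` are hypothetical; `snap` uses the same rounding formula as the attached code. Refinement is modelled by inserting straight-chord midpoints, which is an assumption about where refined vertices land on a straight-edged manifold mesh; those midpoints drift off the equispaced angles, and rounding to the known grid restores an exactly nested sequence.

```python
import numpy as np


def refine_circle(verts):
    """One uniform refinement: insert the midpoint of every chord of a closed polygon."""
    nxt = np.roll(verts, -1, axis=0)
    out = np.empty((2 * len(verts), 2))
    out[0::2] = verts               # old vertices keep even positions
    out[1::2] = 0.5 * (verts + nxt) # new vertices at chord midpoints
    return out


def unroll(verts, length=1.0):
    """Periodic x-coordinate in [0, length) from each vertex's angle."""
    x = np.arctan2(verts[:, 1], verts[:, 0]) / (2 * np.pi)
    return np.where(x < 0, x + 1, x) * length


def snap(x, n, length=1.0):
    """Round coordinates onto the equispaced grid with spacing length / n."""
    return np.round((n / length) * x) * (length / n)


base_n, nref = 8, 3
angles = 2 * np.pi * np.arange(base_n) / base_n
levels = [np.stack([np.cos(angles), np.sin(angles)], axis=1)]
for _ in range(nref):
    levels.append(refine_circle(levels[-1]))

raw = [unroll(v) for v in levels]                               # periodise each level
snapped = [snap(x, base_n * 2**i) for i, x in enumerate(raw)]   # then snap
print("max unrolling error on finest level:", np.max(np.abs(raw[-1] - snapped[-1])))
```

In 2D the same rounding is applied per coordinate. With `diagonal="crossed"` each cell also gains a centre vertex, which is presumably why the attached code snaps to a grid twice as fine (`snapfactor = 2`).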

I agree it should be done in a better way.

Anyway, in case it is useful for your development efforts, I attach the code for this (which does an MMS for a Stokes problem set up in Irksome, but that is incidental).

from firedrake import *
from ufl.algorithms.ad import expand_derivatives
from firedrake.petsc import PETSc
from pyop2.datatypes import IntType
import sys
from irksome import LobattoIIIC, Dt, TimeStepper
import numpy
from datetime import datetime
import pprint
print = lambda x: PETSc.Sys.Print(x)

# Parallel distribution parameters; Vanka requires that each process know
# more about its neighbours than the firedrake default
dist_par = {"partition": True,
            "overlap_type": (DistributedMeshOverlapType.VERTEX, 2)}

# The RK method and number of stages used
butcher_tableau = LobattoIIIC(2)
ns = butcher_tableau.num_stages

# Initialize t and dt
t = Constant(0.0)
dt = Constant(1.0)

# Set up Problem Class for Stokes Problem
class Problem(object):
    def __init__(self, baseN, nref, degree):
        super().__init__()
        self.baseN = baseN
        self.nref = nref
        self.degree = degree

    @staticmethod
    # Making a mesh hierarchy for multigrid in the periodic case is
    # currently a little complicated in firedrake.  Background:
    # firedrake makes an x-periodic mesh by first meshing the surface
    # of a cylinder, then applying a coordinate transformation.  To
    # make a mesh hierarchy of periodic meshes, we have to make our
    # own hierarchy of meshes of the cylinder, then periodise each
    # one.
    def periodise(m):
        coord_fs = VectorFunctionSpace(m, "DG", 1, dim=2)
        old_coordinates = m.coordinates
        new_coordinates = Function(coord_fs)

        # make x-periodic mesh
        # unravel x coordinates like in periodic interval
        # set y coordinates to z coordinates
        domain = "{[i, j]: 0 <= i < old_coords.dofs and 0 <= j < new_coords.dofs}"
        instructions = """
        <float64> Y = 0
        <float64> pi = 3.141592653589793
        for i
            Y = Y + old_coords[i, 1]
        end
        for j
            new_coords[j, 0] = atan2(old_coords[j, 1], old_coords[j, 0]) / (pi* 2)
            new_coords[j, 0] = if(new_coords[j, 0] < 0, new_coords[j, 0] + 1, new_coords[j, 0])
            new_coords[j, 0] = if(new_coords[j, 0] == 0 and Y < 0, 1, new_coords[j, 0])
            new_coords[j, 0] = new_coords[j, 0] * Lx[0]
            new_coords[j, 1] = old_coords[j, 2] * Ly[0]
        end
        """

        cLx = Constant(1)
        cLy = Constant(1)

        par_loop((domain, instructions), dx,
                 {"new_coords": (new_coordinates, WRITE),
                  "old_coords": (old_coordinates, READ),
                  "Lx": (cLx, READ),
                  "Ly": (cLy, READ)},
                 is_loopy_kernel=True)

        return Mesh(new_coordinates)

    @staticmethod
    # Fix errors in mesh coordinates; we know the coordinates
    # of the vertices should live on a certain equispaced grid,
    # so snap to that
    def snap(mesh, N, L=1):
        coords = mesh.coordinates.dat.data
        coords[...] = numpy.round((N / L) * coords) * (L / N)

    # Define mesh
    def mesh(self):
        # Need a hierarchy of periodic meshes, so create a hierarchy
        # of cylinders, then unfold
        base = CylinderMesh(self.baseN, self.baseN, 1.0, 1.0,
                            longitudinal_direction="z",
                            diagonal="crossed",
                            distribution_parameters=dist_par)

        # Callbacks called before and after mesh refinement.
        # This is where we rebalance to try to equidistribute vertices,
        # for better parallel load balancing.
        def before(dm, i): pass
        def after(dm, i):
            try:
                dm.rebalanceSharedPoints(useInitialGuess=False,
                                         parallel=False, entityDepth=0)
            except RuntimeError:
                warning("Vertex rebalancing from scratch failed on level %i" % i)
            try:
                dm.rebalanceSharedPoints(useInitialGuess=True,
                                         parallel=True, entityDepth=0)
            except RuntimeError:
                warning("Vertex rebalancing from initial guess failed on level %i" % i)

        # refine nref times
        mh = MeshHierarchy(base, self.nref,
                           distribution_parameters=dist_par,
                           callbacks=(before, after))

        # Run everybody through Periodise
        meshes = tuple(self.periodise(m) for m in mh)
        # Reconstruct hierarchy
        mh = HierarchyBase(meshes, mh.coarse_to_fine_cells,
                           mh.fine_to_coarse_cells, nested=True)

        snapfactor = 2
        # Snap coordinates
        for (i, m) in enumerate(mh):
            if i > 0:
                self.snap(m, self.baseN * 2**i * snapfactor)

        # Map onto [-1,1]^2
        for mesh in mh:
            mesh.coordinates.dat.data[...] = 2 * mesh.coordinates.dat.data - 1

        return mh[-1]

    # Finite element spaces used: default to degree=1,
    # P2 for velocity(u), P1 for pressure (p)
    def function_space(self, mesh):
        Ve = VectorElement("CG", mesh.ufl_cell(), self.degree + 1)
        Pe = FiniteElement("CG", mesh.ufl_cell(), self.degree)
        Ze = MixedElement([Ve, Pe])

        return FunctionSpace(mesh, Ze)

    # Exact solution of Stokes equation
    def stokes_soln(self, mesh):
        (x, y) = SpatialCoordinate(mesh)
        u = as_vector([sin(pi * x) * cos(pi * y) * exp(-2 * t * (pi**2)),
                       -cos(pi * x) * sin(pi * y) * exp(-2 * t * (pi**2))])
        p = Constant(0, domain=mesh)

        return (u, p)

    # RHS functions for Method of Manufactured Solutions
    def stokes_rhs(self, mesh):
        (uexact, pexact) = self.stokes_soln(mesh)
        u_rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + grad(pexact)
        p_rhs = -div(uexact)

        return (u_rhs, p_rhs)

    # Initial condition for time stepping
    def initial_condition(self, Z):
        (x, y) = SpatialCoordinate(Z.mesh())

        u0 = as_vector([sin(pi * x) * cos(pi * y),
                        -cos(pi * x) * sin(pi * y)])
        p = Constant(0, domain=Z.mesh())

        z = Function(Z)
        z.sub(0).interpolate(u0)
        z.sub(1).interpolate(p)
        return z

    # Define Boundary Conditions
    def bcs(self, Z):
        # We fix the pressure at one vertex on every level.
        class PressureFixBC(DirichletBC):
            def __init__(self, V, val, subdomain):
                super().__init__(V, val, subdomain)
                sec = V.dm.getDefaultSection()
                dm = V.mesh().topology_dm

                coordsSection = dm.getCoordinateSection()
                coordsVec = dm.getCoordinatesLocal()

                (vStart, vEnd) = dm.getDepthStratum(0)
                indices = []
                for pt in range(vStart, vEnd):
                    x = dm.getVecClosure(coordsSection, coordsVec, pt).reshape(-1, 3).mean(axis=0)
                    # fix [1,0, 0] in unmapped mesh coordinates (bottom left corner)
                    if (x[0] == 1.0) and (x[1] == 0.0) and (x[2] == 0.0):
                        if dm.getLabelValue("pyop2_ghost", pt) == -1:
                            indices = [pt]
                        break

                # Check that this worked (the vertex is only guaranteed to be found in serial)
                if V.mesh().comm.size == 1:
                    assert len(indices) == 1, "pressure-fixing vertex not found"

                nodes = []
                for i in indices:
                    if sec.getDof(i) > 0:
                        nodes.append(sec.getOffset(i))

                self.nodes = numpy.asarray(nodes, dtype=IntType)
                if len(self.nodes) > 0:
                    sys.stdout.write("Fixing nodes %s on rank %d on dim %d\n" % (self.nodes, V.mesh().comm.rank, V.dim()))
                    sys.stdout.flush()

        # Dirichlet BCs on u, fix one point for p
        mesh = Z.mesh()
        (u, p) = self.stokes_soln(mesh)
        bcs = [
               DirichletBC(Z.sub(0), u, "on_boundary"),
               PressureFixBC(Z.sub(1), 0, 1)
              ]
        return bcs

    # No nullspace due to pressure BC fix
    def nullspace(self, Z):
        return None

    # Variational Form (steady part)
    def form(self, z, test_z, Z):
        (u, p) = split(z)
        (v, q) = split(test_z)
        (u_rhs, p_rhs) = self.stokes_rhs(Z.mesh())
        F = (
              + inner(grad(u), grad(v)) * dx
              - inner(p, div(v)) * dx
              - inner(q, div(u)) * dx
              - inner(u_rhs, v) * dx
              - inner(p_rhs, q) * dx
            )

        return F

# Main Driver
if __name__ == "__main__":

    # Command-line arguments
    import argparse
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--nref", type=int, default=2)  # Number of refinement levels
    parser.add_argument("--baseN", type=int, default=5)  # Number of elements on coarsest mesh (in one direction)
    parser.add_argument("--degree", type=int, default=1)  # Degree used in FE spaces
    parser.add_argument("--Tf", type=float, default=0.5)  # Final time

    args, _ = parser.parse_known_args()

    # Set up problem data
    problem = Problem(baseN=args.baseN, nref=args.nref, degree=args.degree)
    mesh = problem.mesh()
    Z = problem.function_space(mesh)
    print("Z.dim(): %s" % Z.dim())

    nsp = problem.nullspace(Z)
    bcs = problem.bcs(Z)

    # Solver parameters
    N = args.baseN * (2**args.nref)
    lin_atol = 0.01 / (N**3)
    lin_rtol = 1.0e-8

    # Direct solver options
    lu = {
       "mat_type": "aij",
       "snes_type": "newtonls",
       "snes_monitor": None,
       "ksp_type": "preonly",
       "pc_type": "lu",
       "pc_factor_mat_solver_type": "mumps",
       "snes_linesearch_type": "l2",
       "snes_linesearch_monitor": None,
       "snes_rtol": lin_rtol,
       "snes_atol": lin_atol,
         }

    # Number of relaxation sweeps and the Chebyshev bounds used in Vanka
    relax_its_vanka = 2
    eigenvalue_estimates_vanka = (2, 8)
    gmres_max_its = 50

    # Depending on the number of stages of the IRK method, skip the pressure
    # blocks in coupled Vanka: the stage unknowns are ordered
    # (u_0, p_0, u_1, p_1, ...), so the pressure subspaces have odd indices
    ind_pressure = ",".join(str(2 * i + 1) for i in range(ns))

    vanka = {
           "mat_type": "matfree",
           "snes_type": "newtonls",
           "snes_monitor": None,
           "snes_linesearch_type": "l2",
           "snes_linesearch_monitor": None,
           "snes_converged_reason": None,
           "snes_max_linear_solve_fail": 10,
           "snes_stol": 0,
           "snes_rtol": lin_rtol,
           "snes_atol": lin_atol,
           "ksp_type": "fgmres",
           "ksp_monitor_true_residual": None,
           "ksp_converged_reason": None,
           "ksp_atol": lin_atol,
           "ksp_rtol": lin_rtol,
           "ksp_max_it": gmres_max_its,
           "ksp_gmres_restart": gmres_max_its,
           "pc_type": "mg",
           "pc_mg_cycle_type": "v",
           "pc_mg_type": "multiplicative",
           "mg_levels_ksp_type": "chebyshev",
           "mg_levels_ksp_chebyshev_esteig": "0,0,0,0",
           "mg_levels_ksp_chebyshev_eigenvalues": "%s,%s" % eigenvalue_estimates_vanka,
           "mg_levels_ksp_max_it": relax_its_vanka,
           "mg_levels_ksp_convergence_test": "skip",
           "mg_levels_pc_type": "python",
           "mg_levels_pc_python_type": "firedrake.PatchPC",
           "mg_levels_patch_pc_patch_save_operators": True,
           "mg_levels_patch_pc_patch_partition_of_unity": False,
           "mg_levels_patch_pc_patch_sub_mat_type": "seqdense",
           "mg_levels_patch_pc_patch_construct_dim": 0,
           "mg_levels_patch_pc_patch_construct_type": "vanka",
           "mg_levels_patch_pc_patch_exclude_subspaces": ind_pressure,
           "mg_levels_patch_pc_patch_local_type": "additive",
           "mg_levels_patch_pc_patch_precompute_element_tensors": True,
           "mg_levels_patch_pc_patch_symmetrise_sweep": False,
           "mg_levels_patch_sub_ksp_type": "preonly",
           "mg_levels_patch_sub_pc_type": "lu",
           "mg_levels_patch_sub_pc_factor_shift_type": "nonzero",
           "mg_coarse_ksp_type": "richardson",
           "mg_coarse_pc_type": "python",
           "mg_coarse_pc_python_type": "firedrake.AssembledPC",
           "mg_coarse_assembled_pc_type": "lu",
           "mg_coarse_assembled_pc_factor_mat_solver_type": "mumps",
                     }

    sp = vanka
    if sp == lu:
        print("Solving linear systems using LU")
    elif sp == vanka:
        print("Solving linear systems using Vanka + multigrid")

    # Print solver options, just for data file logging
    if mesh.comm.rank == 0:
        pprint.pprint(sp)

    # Set things up for timestepping from command-line options
    T = args.Tf
    dt = Constant(args.Tf / N)

    # Initial Condition
    z = problem.initial_condition(Z)
    u, p = split(z)
    z_test = TestFunction(Z)
    u_test, p_test = split(z_test)

    # Initialize IRK timestepping
    Func = inner(Dt(u), u_test) * dx + problem.form(z, z_test, Z)
    bcs = problem.bcs(Z)
    stepper = TimeStepper(Func, butcher_tableau, t, dt, z,
                          bcs=bcs, solver_parameters=vanka)

    # Counters for the output data
    total_linear = 0  # Total number of linear iterations
    ntimestep = 0  # Number of timesteps taken

    # Start total solve time
    start = datetime.now()

    while (float(t) < T):

        # The solver solve
        stepper.advance()

        # Advance the RK timestep
        t.assign(float(t) + float(dt))

        # Number of linear iterations per timestep
        linear_its = stepper.solver.snes.getLinearSolveIterations()
        total_linear += linear_its
        ntimestep += 1

    end = datetime.now()
    time = (end - start).total_seconds() / 60
    print(GREEN % ("Time taken: %.2f min, %d linear iterations in total" % (time, total_linear)))

    # Finally, print the relative L2 error and the average number of linear iterations
    (u, p) = z.split()
    uexact, pexact = problem.stokes_soln(mesh)
    rel_u_norm = norm(u - uexact) / norm(uexact)
    print("Relative L2 error in final time solution for u is %.3e" % (rel_u_norm))
    abs_p_norm = norm(p - pexact)
    print("Absolute L2 error in final time solution for p is %.3e" % (abs_p_norm))
    average_linear = total_linear / ntimestep
    print(GREEN % ("Average linear iterations per timestep is %.2f" % (average_linear)))

wence- commented 1 year ago

So the reason for not allowing refinement of overlapped meshes (which is what the periodic meshes are) is that if you do, the halo region on the refined meshes becomes much bigger than necessary, because the halo cells are refined along with the owned ones.
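A toy count (not Firedrake code) shows how fast this grows in 1D. The helpers `halo_cells` and `refine` are hypothetical: one rank owns the left half of a periodic 8-cell mesh with a one-cell halo on each side, and every cell is bisected on each refinement. Refining the halo together with the owned cells doubles its size on every level, while a freshly computed one-cell overlap on the fine mesh stays at two cells.

```python
def halo_cells(owned, n_cells, depth=1):
    """Cells not in `owned` that lie within `depth` cells of it (periodic 1D mesh)."""
    halo = set()
    for c in owned:
        for d in range(1, depth + 1):
            halo.update({(c - d) % n_cells, (c + d) % n_cells})
    return halo - owned


def refine(cells):
    """Bisect every cell: coarse cell c has fine children 2c and 2c + 1."""
    return {child for c in cells for child in (2 * c, 2 * c + 1)}


n_cells = 8
owned = set(range(n_cells // 2))   # this rank owns the left half
halo = halo_cells(owned, n_cells)  # one coarse cell on each side: {4, 7}
growth = []
for level in range(1, 4):
    n_cells *= 2
    owned, halo = refine(owned), refine(halo)  # refine the overlapped mesh
    needed = halo_cells(owned, n_cells)        # a fresh depth-1 overlap instead
    assert needed <= halo
    growth.append((level, len(halo), len(needed)))
    print(f"level {level}: refined halo {len(halo)} cells, needed {len(needed)}")
```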

The right way to do this nicely is, I think:

  1. start with an overlapped coarse mesh (C_O)
  2. DMPlexFilter it to produce a non-overlapped mesh (C_N), remembering the pointSF that pushes forward from C_O to C_N
  3. Refine the non-overlapped mesh to produce a fine mesh (F_N), remembering the point numbering map (this already exists)
  4. Overlap the fine mesh to produce F_O. Again remembering the pointSF that pushes forward from F_N to F_O.

Then you can compose the maps appropriately to produce the thing the multigrid hierarchy needs which is the map from owned cells on C_O to owned cells on F_O.
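The composition in the last step can be sketched with plain integer arrays standing in for the PETSc star forests and the refinement numbering. In this toy model (all names hypothetical), `co_to_cn[i]` is the C_N index of cell i of C_O (or -1 for halo cells removed by the filter), `cn_to_fn` lists the F_N children of each C_N cell, and `fn_to_fo` gives each F_N cell's index in F_O. A real implementation would presumably obtain these from DMPlexFilter, the refinement, and DMPlexDistributeOverlap; this only illustrates how the maps chain together.

```python
import numpy as np


def compose_coarse_to_fine(co_to_cn, cn_to_fn, fn_to_fo, co_owned):
    """Map each owned cell of the overlapped coarse mesh C_O to its children
    in the overlapped fine mesh F_O by chaining C_O -> C_N -> F_N -> F_O."""
    result = {}
    for co_cell in np.flatnonzero(co_owned):
        cn_cell = co_to_cn[co_cell]
        assert cn_cell >= 0, "owned cells must survive the filter"
        result[int(co_cell)] = fn_to_fo[cn_to_fn[cn_cell]]
    return result


# C_O: halo cell 0, owned cells 1 and 2, halo cell 3.
co_owned = np.array([False, True, True, False])
co_to_cn = np.array([-1, 0, 1, -1])     # the filter keeps only the owned cells
cn_to_fn = np.array([[0, 1], [2, 3]])   # each C_N cell is bisected
fn_to_fo = np.array([1, 2, 3, 4])       # overlap adds halo cells 0 and 5 in F_O
coarse_to_fine = compose_coarse_to_fine(co_to_cn, cn_to_fn, fn_to_fo, co_owned)
print(coarse_to_fine)
```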

Previously (many moons ago, when I wrote the code) DMPlexFilter didn't exist (or I didn't understand it), so I simplified the problem by generating the cell-to-cell map between a pair of related non-overlapped meshes and then adding overlap on every level.

connorjward commented 1 year ago

@pefarrell thank you for the code. I think that it would likely be better to pursue a solution with tighter integration with DMPlex. @wence- this is almost exactly what I had been discussing with @ksagiyam!