Open connorjward opened 1 year ago
It seems we should make those Firedrake DG coordinates in the callback.
The way I got around this previously was by making a base mesh of the manifold, making the mesh hierarchy of that, periodising each one, then snapping the vertices of the periodised meshes so that they form a nested sequence.
I agree it should be done in a better way.
Anyway, in case it is useful for your development efforts, I attach the code for this (it runs a method-of-manufactured-solutions (MMS) test for a Stokes problem set up in Irksome, but that is incidental).
from firedrake import *
from ufl.algorithms.ad import expand_derivatives
from firedrake.petsc import PETSc
from pyop2.datatypes import IntType
import sys
from irksome import LobattoIIIC, Dt, TimeStepper
import numpy
from datetime import datetime
import pprint
def print(*args, **kwargs):
    """Parallel-safe replacement for the builtin ``print``.

    Forwards everything to ``PETSc.Sys.Print`` (which, per the petsc4py
    documentation, prints from a single rank rather than once per process),
    so every ``print`` call in this script avoids duplicated output under MPI.
    """
    PETSc.Sys.Print(*args, **kwargs)


# Parallel distribution parameters; Vanka requires that each process know
# more about its neighbours than the firedrake default (vertex-based overlap
# of depth 2 instead of the default facet overlap).
dist_par = {"partition": True,
            "overlap_type": (DistributedMeshOverlapType.VERTEX, 2)}

# The RK method and number of stages used
butcher_tableau = LobattoIIIC(2)
ns = butcher_tableau.num_stages

# Initialize t and dt.  dt is only a placeholder here: the driver below
# rebinds it to Tf / N once the command-line options are known.
t = Constant(0.0)
dt = Constant(1.0)
# Set up Problem Class for Stokes Problem
class Problem(object):
    """Manufactured-solution Stokes problem on an x-periodic square.

    Supplies everything the driver needs:

    - the multigrid mesh hierarchy (the finest mesh is returned);
    - the mixed velocity/pressure space;
    - the exact solution and forcing;
    - the boundary conditions;
    - the steady part of the variational form.

    The time-derivative term is added by the caller.
    """

    def __init__(self, baseN, nref, degree):
        """Store the discretisation parameters.

        :arg baseN: number of elements per direction on the coarsest mesh.
        :arg nref: number of refinement levels in the mesh hierarchy.
        :arg degree: pressure degree; the velocity uses ``degree + 1``.
        """
        super().__init__()
        self.baseN = baseN
        self.nref = nref
        self.degree = degree

    # Making a mesh hierarchy for multigrid in the periodic case is
    # currently a little complicated in firedrake. Background:
    # firedrake makes an x-periodic mesh by first meshing the surface
    # of a cylinder, then applying a coordinate transformation. To
    # make a mesh hierarchy of periodic meshes, we have to make our
    # own hierarchy of meshes of the cylinder, then periodise each
    # one.
    @staticmethod
    def periodise(m):
        """Return an x-periodic 2D mesh obtained by unrolling cylinder mesh ``m``.

        The angle around the cylinder axis becomes x (scaled to [0, Lx]) and
        the axial coordinate z becomes y. The new coordinates are DG1, so a
        vertex on the seam can have x = 0 in one cell and x = Lx in another.
        """
        coord_fs = VectorFunctionSpace(m, "DG", 1, dim=2)
        old_coordinates = m.coordinates
        new_coordinates = Function(coord_fs)

        # make x-periodic mesh
        # unravel x coordinates like in periodic interval
        # set y coordinates to z coordinates
        # Per cell, Y is the sum of the vertices' y-coordinates.  It decides
        # which side of the seam (angle 0) the cell lies on, so seam vertices
        # of cells with Y < 0 get x = 1 instead of 0.
        domain = "{[i, j]: 0 <= i < old_coords.dofs and 0 <= j < new_coords.dofs}"
        instructions = """
<float64> Y = 0
<float64> pi = 3.141592653589793
for i
Y = Y + old_coords[i, 1]
end
for j
new_coords[j, 0] = atan2(old_coords[j, 1], old_coords[j, 0]) / (pi* 2)
new_coords[j, 0] = if(new_coords[j, 0] < 0, new_coords[j, 0] + 1, new_coords[j, 0])
new_coords[j, 0] = if(new_coords[j, 0] == 0 and Y < 0, 1, new_coords[j, 0])
new_coords[j, 0] = new_coords[j, 0] * Lx[0]
new_coords[j, 1] = old_coords[j, 2] * Ly[0]
end
"""
        cLx = Constant(1)
        cLy = Constant(1)
        par_loop((domain, instructions), dx,
                 {"new_coords": (new_coordinates, WRITE),
                  "old_coords": (old_coordinates, READ),
                  "Lx": (cLx, READ),
                  "Ly": (cLy, READ)},
                 is_loopy_kernel=True)
        return Mesh(new_coordinates)

    # Fix errors in mesh coordinates; we know the coordinates
    # of the vertices should live on a certain equispaced grid,
    # so snap to that
    @staticmethod
    def snap(mesh, N, L=1):
        """Round ``mesh``'s coordinates in place to the nearest multiple of ``L / N``."""
        coords = mesh.coordinates.dat.data
        coords[...] = numpy.round((N / L) * coords) * (L / N)

    def mesh(self):
        """Build the periodic mesh hierarchy and return its finest mesh.

        The steps are:

        1. Refine a hierarchy of cylinder meshes.
        2. Periodise every level.
        3. Snap the refined levels' vertices back onto the equispaced grid so
           the levels stay nested.
        4. Map the coordinates from [0, 1]^2 onto [-1, 1]^2.
        """
        # Need a hierarchy of periodic meshes, so create a hierarchy
        # of cylinders, then unfold
        base = CylinderMesh(self.baseN, self.baseN, 1.0, 1.0,
                            longitudinal_direction="z",
                            diagonal="crossed",
                            distribution_parameters=dist_par)

        # Callbacks called before and after mesh refinement.
        # This is where we rebalance to try to equidistribute vertices,
        # for better parallel load balancing.
        def before(dm, i):
            pass

        def after(dm, i):
            # Both strategies are best-effort: a failure only costs load
            # balance, so it is reported as a warning rather than raised.
            try:
                dm.rebalanceSharedPoints(useInitialGuess=False,
                                         parallel=False, entityDepth=0)
            except RuntimeError:
                warning("Vertex rebalancing from scratch failed on level %i" % i)
            try:
                dm.rebalanceSharedPoints(useInitialGuess=True,
                                         parallel=True, entityDepth=0)
            except RuntimeError:
                warning("Vertex rebalancing from initial guess failed on level %i" % i)

        # refine nref times
        mh = MeshHierarchy(base, self.nref,
                           distribution_parameters=dist_par,
                           callbacks=(before, after))

        # Run everybody through Periodise
        meshes = tuple(self.periodise(m) for m in mh)

        # Reconstruct hierarchy, reusing the cylinder hierarchy's cell maps
        mh = HierarchyBase(meshes, mh.coarse_to_fine_cells,
                           mh.fine_to_coarse_cells, nested=True)

        # Snap coordinates of the refined levels
        snapfactor = 2
        for (i, m) in enumerate(mh):
            if i > 0:
                self.snap(m, self.baseN * 2**i * snapfactor)

        # Map onto [-1,1]^2
        for m in mh:
            m.coordinates.dat.data[...] = 2 * m.coordinates.dat.data - 1
        return mh[-1]

    def function_space(self, mesh):
        """Return the mixed space: vector CG(degree + 1) velocity and CG(degree) pressure."""
        Ve = VectorElement("CG", mesh.ufl_cell(), self.degree + 1)
        Pe = FiniteElement("CG", mesh.ufl_cell(), self.degree)
        Ze = MixedElement([Ve, Pe])
        return FunctionSpace(mesh, Ze)

    def stokes_soln(self, mesh):
        """Return the exact ``(u, p)`` at the current time ``t``.

        The velocity is a vortex decaying like exp(-2 pi^2 t); the pressure is
        zero.
        """
        (x, y) = SpatialCoordinate(mesh)
        u = as_vector([sin(pi * x) * cos(pi * y) * exp(-2 * t * (pi**2)),
                       -cos(pi * x) * sin(pi * y) * exp(-2 * t * (pi**2))])
        p = Constant(0, domain=mesh)
        return (u, p)

    def stokes_rhs(self, mesh):
        """Return the manufactured forcing ``(u_rhs, p_rhs)`` for the exact solution.

        u_rhs = du/dt - div(grad(u)) + grad(p) and p_rhs = -div(u).
        """
        (uexact, pexact) = self.stokes_soln(mesh)
        u_rhs = expand_derivatives(diff(uexact, t)) - div(grad(uexact)) + grad(pexact)
        p_rhs = -div(uexact)
        return (u_rhs, p_rhs)

    def initial_condition(self, Z):
        """Return a Function on ``Z`` holding the exact solution at t = 0."""
        mesh = Z.mesh()
        (x, y) = SpatialCoordinate(mesh)
        u0 = as_vector([sin(pi * x) * cos(pi * y),
                        -cos(pi * x) * sin(pi * y)])
        # Use Z's own mesh, not a module-level global, as the Constant's domain.
        p = Constant(0, domain=mesh)
        z = Function(Z)
        z.sub(0).interpolate(u0)
        z.sub(1).interpolate(p)
        return z

    def bcs(self, Z):
        """Return the boundary conditions for the mixed space ``Z``.

        The velocity is set to the exact solution on ``"on_boundary"``. The
        pressure is pinned to zero at one vertex on every level; this is
        why ``nullspace`` returns None.

        :raises RuntimeError: in serial, if the pinning vertex is not found.
        """
        class PressureFixBC(DirichletBC):
            """DirichletBC restricted to the dof at a single mesh vertex."""

            def __init__(self, V, val, subdomain):
                super().__init__(V, val, subdomain)
                sec = V.dm.getDefaultSection()
                dm = V.mesh().topology_dm

                # The topology DM still carries the unmapped 3D cylinder
                # coordinates, which is what the vertex search below uses.
                coordsSection = dm.getCoordinateSection()
                coordsVec = dm.getCoordinatesLocal()

                (vStart, vEnd) = dm.getDepthStratum(0)
                indices = []
                for pt in range(vStart, vEnd):
                    x = dm.getVecClosure(coordsSection, coordsVec, pt).reshape(-1, 3).mean(axis=0)
                    # fix [1, 0, 0] in unmapped mesh coordinates (bottom left corner)
                    if (x[0] == 1.0) and (x[1] == 0.0) and (x[2] == 0.0):
                        # Only constrain it on a rank where it is not a ghost.
                        if dm.getLabelValue("pyop2_ghost", pt) == -1:
                            indices = [pt]
                            break

                # In serial the vertex must be found; use an exception rather
                # than an assert so the check survives ``python -O``.
                if V.mesh().comm.size == 1 and len(indices) != 1:
                    raise RuntimeError(
                        "PressureFixBC: could not locate the pressure-pinning vertex")

                nodes = [sec.getOffset(i) for i in indices if sec.getDof(i) > 0]
                self.nodes = numpy.asarray(nodes, dtype=IntType)
                if len(self.nodes) > 0:
                    sys.stdout.write("Fixing nodes %s on rank %d on dim %d\n" % (self.nodes, V.mesh().comm.rank, V.dim()))
                    sys.stdout.flush()

        # Dirichlet BCs on u, fix one point for p
        mesh = Z.mesh()
        (u, p) = self.stokes_soln(mesh)
        bcs = [
            DirichletBC(Z.sub(0), u, "on_boundary"),
            PressureFixBC(Z.sub(1), 0, 1),
        ]
        return bcs

    def nullspace(self, Z):
        """Return None: the pressure fix in ``bcs`` removes the constant nullspace."""
        return None

    def form(self, z, test_z, Z):
        """Return the steady Stokes residual (without the time-derivative term)."""
        (u, p) = split(z)
        (v, q) = split(test_z)
        (u_rhs, p_rhs) = self.stokes_rhs(Z.mesh())
        F = (
            + inner(grad(u), grad(v)) * dx
            - inner(p, div(v)) * dx
            - inner(q, div(u)) * dx
            - inner(u_rhs, v) * dx
            - inner(p_rhs, q) * dx
        )
        return F
# Main Driver
if __name__ == "__main__":
    # Command-line arguments
    import argparse
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument("--nref", type=int, default=2)  # Number of refinement levels
    parser.add_argument("--baseN", type=int, default=5)  # Number of elements on coarsest mesh (in one direction)
    parser.add_argument("--degree", type=int, default=1)  # Degree used in FE spaces
    parser.add_argument("--Tf", type=float, default=0.5)  # Final time
    # parse_known_args tolerates extra (e.g. PETSc) command-line options.
    args, _ = parser.parse_known_args()

    # Set up problem data
    problem = Problem(baseN=args.baseN, nref=args.nref, degree=args.degree)
    mesh = problem.mesh()
    Z = problem.function_space(mesh)
    print("Z.dim(): %s" % Z.dim())

    # Solver parameters; the absolute tolerance scales with the finest
    # resolution N (cells per direction on the finest level).
    N = args.baseN * (2**args.nref)
    lin_atol = 0.01 / (N**3)
    lin_rtol = 1.0e-8

    # Direct solver options
    lu = {
        "mat_type": "aij",
        "snes_type": "newtonls",
        "snes_monitor": None,
        "ksp_type": "preonly",
        "pc_type": "lu",
        "pc_factor_mat_solver_type": "mumps",
        "snes_linesearch_type": "l2",
        "snes_linesearch_monitor": None,
        "snes_rtol": lin_rtol,
        "snes_atol": lin_atol,
    }

    # Number of relaxation sweeps and the Chebyshev bounds used in Vanka
    relax_its_vanka = 2
    eigenvalue_estimates_vanka = (2, 8)
    gmres_max_its = 50

    # Pressure subspaces to exclude from the coupled Vanka patches: the odd
    # indices 1, 3, ..., 2*ns - 1 (one per RK stage).  This assumes the
    # stage-coupled space interleaves (u, p) per stage, and it works for
    # any number of stages.
    ind_pressure = ",".join(str(2 * k + 1) for k in range(ns))

    vanka = {
        "mat_type": "matfree",
        "snes_type": "newtonls",
        "snes_monitor": None,
        "snes_linesearch_type": "l2",
        "snes_linesearch_monitor": None,
        "snes_converged_reason": None,
        "snes_max_linear_solve_fail": 10,
        "snes_stol": 0,
        "snes_rtol": lin_rtol,
        "snes_atol": lin_atol,
        "ksp_type": "fgmres",
        "ksp_monitor_true_residual": None,
        "ksp_converged_reason": None,
        "ksp_atol": lin_atol,
        "ksp_rtol": lin_rtol,
        "ksp_max_it": gmres_max_its,
        "ksp_gmres_restart": gmres_max_its,
        "pc_type": "mg",
        "pc_mg_cycle_type": "v",
        "pc_mg_type": "multiplicative",
        "mg_levels_ksp_type": "chebyshev",
        "mg_levels_ksp_chebyshev_esteig": "0,0,0,0",
        "mg_levels_ksp_chebyshev_eigenvalues": "%s,%s" % eigenvalue_estimates_vanka,
        "mg_levels_ksp_max_it": relax_its_vanka,
        "mg_levels_ksp_convergence_test": "skip",
        "mg_levels_pc_type": "python",
        "mg_levels_pc_python_type": "firedrake.PatchPC",
        "mg_levels_patch_pc_patch_save_operators": True,
        "mg_levels_patch_pc_patch_partition_of_unity": False,
        "mg_levels_patch_pc_patch_sub_mat_type": "seqdense",
        "mg_levels_patch_pc_patch_construct_dim": 0,
        "mg_levels_patch_pc_patch_construct_type": "vanka",
        "mg_levels_patch_pc_patch_exclude_subspaces": ind_pressure,
        "mg_levels_patch_pc_patch_local_type": "additive",
        "mg_levels_patch_pc_patch_precompute_element_tensors": True,
        "mg_levels_patch_pc_patch_symmetrise_sweep": False,
        "mg_levels_patch_sub_ksp_type": "preonly",
        "mg_levels_patch_sub_pc_type": "lu",
        "mg_levels_patch_sub_pc_factor_shift_type": "nonzero",
        "mg_coarse_ksp_type": "richardson",
        "mg_coarse_pc_type": "python",
        "mg_coarse_pc_python_type": "firedrake.AssembledPC",
        "mg_coarse_assembled_pc_type": "lu",
        "mg_coarse_assembled_pc_factor_mat_solver_type": "mumps",
    }

    # Selected solver configuration; it is passed to the time stepper below.
    sp = vanka
    if sp == lu:
        print("Solving linear systems using LU")
    elif sp == vanka:
        print("Solving linear systems using Vanka + multigrid")

    # Print solver options, just for data file logging
    if mesh.comm.rank == 0:
        pprint.pprint(sp)

    # Set things up for timestepping from command-line options
    T = args.Tf
    dt = Constant(args.Tf / N)

    # Initial Condition
    z = problem.initial_condition(Z)
    u, p = split(z)
    z_test = TestFunction(Z)
    u_test, p_test = split(z_test)

    # Initialize IRK timestepping with the solver configuration chosen above
    Func = inner(Dt(u), u_test) * dx + problem.form(z, z_test, Z)
    bcs = problem.bcs(Z)
    stepper = TimeStepper(Func, butcher_tableau, t, dt, z,
                          bcs=bcs, solver_parameters=sp)

    # Counters for the output data
    total_linear = 0  # Total number of linear iterations
    ntimestep = 0  # Number of timesteps taken

    # Start total solve time
    start = datetime.now()
    # Stop half a step early so that floating-point drift in the accumulated
    # t (e.g. 0.49999... < 0.5) cannot trigger an extra timestep.
    while float(t) < T - 0.5 * float(dt):
        # The solver solve
        stepper.advance()
        # Advance the RK timestep
        t.assign(float(t) + float(dt))
        # Number of linear iterations in this timestep
        linear_its = stepper.solver.snes.getLinearSolveIterations()
        total_linear += linear_its
        ntimestep += 1
    end = datetime.now()
    elapsed_min = (end - start).total_seconds() / 60
    print(GREEN % ("Time taken: %.2f min in %d iterations" % (elapsed_min, total_linear)))

    # Finally, we print out the relative :math:`L^2` error and number of average linear iterations:
    (u, p) = z.split()
    uexact, pexact = problem.stokes_soln(mesh)
    rel_u_norm = norm(u - uexact) / norm(uexact)
    print("Relative L2 error in final time solution for u is %.3e" % (rel_u_norm))
    abs_p_norm = norm(p - pexact)
    print("Absolute L2 error in final time solution for p is %.3e" % (abs_p_norm))
    # Guard against a non-positive final time, where no step is taken.
    average_linear = total_linear / ntimestep if ntimestep else 0.0
    print(GREEN % ("Average linear iterations per timestep is %.2f" % (average_linear)))
So the reason for not allowing refinement of overlapped meshes (which is what the periodic meshes are) is that if you do that then the halo region gets bigger than necessary (by a lot) on the refined meshes.
The right way to do this nicely is, I think:
Then you can compose the maps appropriately to produce what the multigrid hierarchy needs: the map from owned cells on C_O to owned cells on F_O (the overlapped coarse and fine meshes).
Previously (many moons ago, when I wrote the code), DMPlexFilter didn't exist (or I didn't understand it), so I simplified the problem by generating the cell-to-cell map between a pair of related non-overlapped meshes and then adding overlap on every level.
@pefarrell thank you for the code. I think that it would likely be better to pursue a solution with tighter integration with DMPlex. @wence- this is almost exactly what I had been discussing with @ksagiyam!
Running the following code in parallel
raises the exception
The same error is observed for other periodic meshes like `PeriodicRectangleMesh`. This error happens because we create these periodic meshes by first creating a manifold mesh and then transforming the coordinates. This means we always have code in each `utility_meshes.py` function like (example):

Constructing this `VectorFunctionSpace` calls `mesh.init()`, which then means that we cannot construct a mesh hierarchy using the topology.