jorgensd / adios4dolfinx

Extending DOLFINx with checkpointing functionality
http://jsdokken.com/adios4dolfinx/
MIT License

checkpointing across three different runs #62

Closed francesco-ballarin closed 6 months ago

francesco-ballarin commented 7 months ago

The current checkpointing strategy assumes that there is a first run that generates the checkpoint data, and a second run that reads it back in.

Consider instead the following case, in which there are three separate runs:

  1. generate a mesh, and save it to file mesh_1.bp
  2. load the mesh back from mesh_1.bp and save it to mesh_2.bp
  3. load the mesh back from mesh_2.bp
# test.py
from mpi4py import MPI
import adios4dolfinx
import dolfinx
from pathlib import Path

# Original mesh
comm = MPI.COMM_WORLD
domain = dolfinx.mesh.create_unit_square(
    comm, 2, 2, ghost_mode=dolfinx.mesh.GhostMode.none)

# Write original mesh to mesh_1.bp
out_path_1 = Path("mesh_1.bp")
adios4dolfinx.write_mesh(domain, out_path_1, "BP4")

# Load the mesh back from mesh_1.bp
mesh_1 = adios4dolfinx.read_mesh(
    MPI.COMM_WORLD, out_path_1, engine="BP4", ghost_mode=dolfinx.mesh.GhostMode.none)

# Write the (possibly reordered) mesh to mesh_2.bp
out_path_2 = Path("mesh_2.bp")
adios4dolfinx.write_mesh(mesh_1, out_path_2, "BP4")

# Load the mesh back from mesh_2.bp
mesh_2 = adios4dolfinx.read_mesh(
    MPI.COMM_WORLD, out_path_2, engine="BP4", ghost_mode=dolfinx.mesh.GhostMode.none)

# The print shows that local sizes may be different
print(comm.rank, mesh_1.geometry.index_map().size_local, mesh_2.geometry.index_map().size_local)

# rm -rf *.bp && mpirun -n 3 python3 test.py
# prints
# 0 5 4
# 1 4 5
# 2 0 0

If the user were also writing function checkpoints, then at step 3 they would ideally want to load the checkpoints from both step 1 and step 2 onto the same mesh (say, mesh_2). However, they cannot, because mesh_1 and mesh_2 differ: the mesh partitioner repartitions the mesh when it is read back in, even though the mesh should already be well balanced by the initial partitioning.
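Storing checkpoint data by each node's index in the original input, rather than by its process-local position, is the general idea behind making data survive repartitioning. The toy model below simulates this with plain NumPy. The "ranks", the partitions `part_1`/`part_2`, and the helpers `write_checkpoint`/`read_checkpoint` are all hypothetical and are not the adios4dolfinx API. The local sizes (5, 4, 0) and (4, 5, 0) merely mirror the printed output above.

```python
import numpy as np

# Toy "input" data: the 9 nodes of a 3x3 grid on the unit square, in input order.
xs, ys = np.meshgrid(np.linspace(0, 1, 3), np.linspace(0, 1, 3))
coords = np.column_stack([xs.ravel(), ys.ravel()])
num_nodes = coords.shape[0]


def f(x):
    return x[:, 0] ** 2 + x[:, 1] ** 3


# Hypothetical partitions: for each rank, the input indices of the nodes it owns.
# The two runs distribute the same nodes differently.
part_1 = [np.array([4, 0, 1, 3, 2]), np.array([8, 5, 7, 6]), np.array([], dtype=int)]
part_2 = [np.array([6, 7, 8, 5]), np.array([0, 1, 2, 3, 4]), np.array([], dtype=int)]


def write_checkpoint(partition, local_values):
    """Scatter each rank's local values into one global array in input order."""
    stored = np.empty(num_nodes)
    for input_idx, values in zip(partition, local_values):
        stored[input_idx] = values
    return stored


def read_checkpoint(partition, stored):
    """Give each rank the stored values for the input indices it now owns."""
    return [stored[input_idx] for input_idx in partition]


# Run 1: each rank evaluates f on its own nodes and writes a checkpoint.
local_1 = [f(coords[idx]) for idx in part_1]
stored = write_checkpoint(part_1, local_1)

# Run 2: a different partition reads the same checkpoint.
local_2 = read_checkpoint(part_2, stored)
for idx, values in zip(part_2, local_2):
    np.testing.assert_allclose(values, f(coords[idx]))
print("checkpoint survives repartitioning")
```

The raw per-rank arrays from run 1 could not be reused directly in run 2 (even their lengths differ). Going through the input index removes that dependence on the partition.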

Motivates https://github.com/jorgensd/adios4dolfinx/issues/27

jorgensd commented 6 months ago

Introduced in #70

francesco-ballarin commented 5 months ago

As usual, I got confused when trying to use this ;)

To confirm: as you mention in PR #70, is it necessary to write the mesh out to XDMF and read it back in (i.e., the three-run case above) to use original checkpointing? In other words, skipping one of the three runs is not supposed to work?

Consider this small modification of the original_checkpoint.py demo:

from pathlib import Path
from typing import Tuple

from mpi4py import MPI
import numpy as np

import dolfinx
import adios4dolfinx

def create_xdmf_mesh(filename: Path):
    comm = MPI.COMM_WORLD
    mesh = dolfinx.mesh.create_unit_square(comm, 4 * comm.size, 4 * comm.size, ghost_mode=dolfinx.mesh.GhostMode.shared_facet)
    with dolfinx.io.XDMFFile(comm, filename.with_suffix(".xdmf"), "w") as xdmf:
        xdmf.write_mesh(mesh)
    print(f"{mesh.comm.rank+1}/{mesh.comm.size} Mesh written to {filename.with_suffix('.xdmf')}")
    return mesh

mesh_file = Path("MyMesh.xdmf")
mesh = create_xdmf_mesh(mesh_file)
# mesh_for_write_function = dolfinx.io.XDMFFile(MPI.COMM_WORLD, mesh_file, "r").read_mesh  # WORKS - 3 runs
mesh_for_write_function = lambda: mesh  # DOES NOT WORK - 2 runs
mesh_for_verify_checkpoint = dolfinx.io.XDMFFile(MPI.COMM_WORLD, mesh_file, "r").read_mesh
print(mesh_for_write_function)
print(mesh_for_verify_checkpoint)

# Next, we will create a function on the mesh and write it to a checkpoint.

def f(x):
    return x[0]**2 + x[1]**3

def write_function(
    mesh_reader, function_filename: Path, element: Tuple[str, int]
):
    mesh = mesh_reader()
    V = dolfinx.fem.functionspace(mesh, element)
    u = dolfinx.fem.Function(V)
    u.interpolate(f)
    print(u.x.array)

    adios4dolfinx.write_function_on_input_mesh(
        function_filename.with_suffix(".bp"),
        u,
        "bp4"
    )
    print(
        f"{mesh.comm.rank+1}/{mesh.comm.size} Function written to ",
        f"{function_filename.with_suffix('.bp')}",
    )

# Read in mesh and write function to file

element = ("CG", 1)
function_file = Path("MyFunction.bp")
write_function(mesh_for_write_function, function_file, element)

# Finally, we will read in the mesh from file and the function from the checkpoint
# and compare it with the analytical solution.

def verify_checkpoint(
    mesh_reader, function_filename: Path, element: Tuple[str, int]
):
    in_mesh = mesh_reader()
    V = dolfinx.fem.functionspace(in_mesh, element)
    u_ex = dolfinx.fem.Function(V)
    u_ex.interpolate(f)
    u_in = dolfinx.fem.Function(V)
    adios4dolfinx.read_function(function_filename.with_suffix(".bp"), u_in, "bp4")

    np.testing.assert_allclose(u_in.x.array, u_ex.x.array)
    print(
        "Successfully read checkpoint onto mesh on rank ",
        f"{in_mesh.comm.rank + 1}/{in_mesh.comm.size}",
    )

# Verify checkpoint by comparing to exact solution
verify_checkpoint(mesh_for_verify_checkpoint, function_file, element)

Everything works as expected if I uncomment the line ending with "WORKS - 3 runs" and comment out the line below it. Leaving the code unchanged (i.e., using the lambda on the line ending with "DOES NOT WORK - 2 runs") results in

Mismatched elements: 24 / 25 (96%)
Max absolute difference: 1.328125
Max relative difference: 43.
 x: array([0.25    , 0.0625  , 0.078125, 0.578125, 1.      , 0.078125,
       0.375   , 1.015625, 0.578125, 0.6875  , 0.6875  , 0.671875,
       1.125   , 0.125   , 0.671875, 0.6875  , 0.671875, 1.0625  ,...
 y: array([0.5625  , 1.      , 1.015625, 0.578125, 0.25    , 1.125   ,
       0.265625, 0.6875  , 0.0625  , 1.421875, 0.078125, 0.375   ,
       0.984375, 0.      , 2.      , 0.015625, 0.1875  , 0.671875,...

when running in serial with

rm -rf My* && python3 original_checkpoint.py
jorgensd commented 5 months ago

The key here is that when DOLFINx creates a unit square, the points and cells get reordered relative to the original input data. You can see this reordering in geometry.input_global_indices and topology.original_cell_index.
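Because `topology.original_cell_index` maps each local cell back to its position in the input data, two meshes built from the same input can be related through that common index. This also explains why a positional comparison such as `np.testing.assert_allclose` reports mismatches when the two arrays follow different orderings. Below is a serial, NumPy-only sketch in which the made-up arrays `orig_a` and `orig_b` stand in for that attribute on two meshes. In parallel, this mapping would additionally require communication between ranks.

```python
import numpy as np

# Hypothetical stand-ins for topology.original_cell_index on two meshes built
# from the same 6 input cells: entry i is the input index of local cell i.
orig_a = np.array([3, 0, 5, 1, 4, 2])
orig_b = np.array([1, 4, 0, 2, 5, 3])

# Inverse of orig_b: position of each input cell in mesh B's local ordering.
input_to_b = np.empty_like(orig_b)
input_to_b[orig_b] = np.arange(orig_b.size)

# Local cell i on mesh A is the same input cell as local cell a_to_b[i] on mesh B.
a_to_b = input_to_b[orig_a]

# Cell-wise data defined in input order, laid out in each mesh's local order.
input_data = np.array([10.0, 11.0, 12.0, 13.0, 14.0, 15.0])
data_a = input_data[orig_a]
data_b = input_data[orig_b]

# Comparing position by position pairs up different cells...
print(np.allclose(data_a, data_b))  # False
# ...while comparing through the common input index matches.
print(np.allclose(data_a, data_b[a_to_b]))  # True
```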

I.e., your code should work with either:

mesh = create_xdmf_mesh(mesh_file)
mesh_for_write_function = lambda: mesh
mesh_for_verify_checkpoint = lambda: create_xdmf_mesh(mesh_file)

or

mesh = create_xdmf_mesh(mesh_file)
mesh_for_write_function = lambda: mesh
mesh_for_verify_checkpoint = lambda: mesh

(I have not tested this, but feel free to verify it.)
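The following toy model illustrates why both suggestions should work while verifying on the mesh read back from XDMF does not. It rests on simplifying assumptions that are not DOLFINx or adios4dolfinx API. First, the checkpoint stores values in the input ordering of the mesh it was written from. Second, building a mesh reorders its nodes by a fixed permutation `reorder`. Third, the XDMF file stores the mesh in that reordered layout, so a mesh read from it has a different input ordering. `build_mesh` and `generated` are hypothetical names.

```python
import numpy as np

# Input ordering produced by a (hypothetical) mesh generator: x-coordinates of 5 nodes.
generated = np.array([0.0, 0.25, 0.5, 0.75, 1.0])
# Toy stand-in for the reordering applied when a mesh is built.
reorder = np.array([2, 0, 4, 1, 3])


def f(x):
    return x**2


def build_mesh(input_nodes):
    """Return (input nodes, nodes in the mesh's internal reordered layout)."""
    return input_nodes, input_nodes[reorder]


# Writing run: the checkpoint is stored in the mesh's *input* ordering.
mesh_input, mesh_nodes = build_mesh(generated)
checkpoint = f(mesh_input)

# The XDMF file holds the reordered nodes, so a mesh read from it starts
# from a different input ordering.
xdmf_input, _ = build_mesh(mesh_nodes)
# A mesh created the same way as the written one shares its input ordering.
same_input, _ = build_mesh(generated)

print(np.allclose(checkpoint, f(xdmf_input)))  # False: orderings differ
print(np.allclose(checkpoint, f(same_input)))  # True: orderings agree
```

The same model also covers the "WORKS - 3 runs" variant. Writing and verifying on meshes read from the same XDMF file means that both sides start from the file's ordering.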