mfem / PyMFEM

Python wrapper for MFEM
http://mfem.org
BSD 3-Clause "New" or "Revised" License

Performance issues w/ mesh attribute retrieval #216

Open tradeqvest opened 6 months ago

tradeqvest commented 6 months ago

Hello,

For my application, I constantly need to retrieve mesh element data, e.g. `mesh.GetElementTransformation(i)` or `mesh.GetElementVertices(i)`. Since this requires looping over every element, performance suffers significantly. Is there a more efficient way that I am overlooking? Is there a way to vectorize the retrieval?

I would appreciate any insights! Thanks in advance for your time and help!

sshiraiwa commented 6 months ago

As for `GetElementVertices`, there is `Mesh::GetVertexToElementTable`. It returns the vertex-to-element mapping as a `Table`. Using the I and J arrays of this table, you can build the reverse mapping from element to vertex. In the snippet below, I construct a `scipy.sparse.csr_matrix` from I and J, then take its transpose and convert it back to CSR with `tocsr()`. The `indices` and `indptr` arrays of the result give the mapping from element to vertices.

```python
import numpy as np
import mfem.ser as mfem
from scipy.sparse import csr_matrix

# Any mfem.Mesh works here (the original snippet obtained it from a project
# object, proj.model1.mfem.variables.eval("mesh")); "star.mesh" is an example file.
mesh = mfem.Mesh("star.mesh", 1, 1)
tb = mesh.GetVertexToElementTable()
nnz = tb.Size_of_connections()

# Wrap the Table's row pointers (I) and column indices (J) as numpy arrays.
I = mfem.intArray((tb.GetI(), mesh.GetNV())).GetDataArray()
I = np.hstack((I, nnz))  # need to append the total length
J = mfem.intArray((tb.GetJ(), nnz)).GetDataArray()

# vertex x element pattern -> transpose -> element x vertex CSR
mat = csr_matrix(
    (np.ones(nnz), J, I), shape=(mesh.GetNV(), mesh.GetNE())
).transpose().tocsr()

# well.. let's check if this is correct ;D
for e in range(mesh.GetNE()):
    iverts = mat.indices[mat.indptr[e]:mat.indptr[e + 1]]
    iverts2 = mesh.GetElementVertices(e)
    if np.any(np.sort(iverts) != np.sort(iverts2)):
        print("error", e, iverts, iverts2)
```
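To make the transpose step concrete without SciPy or a real mesh, here is a minimal NumPy-only sketch of transposing a CSR sparsity pattern. The helper name `transpose_csr_pattern` and the two-triangle toy table are hypothetical; with PyMFEM, `I` and `J` would be the arrays built in the snippet above and `n_cols` would be `mesh.GetNE()`.

```python
import numpy as np


def transpose_csr_pattern(I, J, n_cols):
    """Transpose a CSR sparsity pattern given row pointers I and column indices J.

    Returns (indptr, indices) of the transposed pattern.
    """
    I = np.asarray(I)
    J = np.asarray(J)
    n_rows = len(I) - 1
    # Row index of every stored entry (here: the vertex each entry belongs to).
    rows = np.repeat(np.arange(n_rows), np.diff(I))
    # A stable sort by column groups entries by their new row (the element).
    order = np.argsort(J, kind="stable")
    indices = rows[order]
    # Row pointers of the transpose come from the per-column entry counts.
    counts = np.bincount(J, minlength=n_cols)
    indptr = np.concatenate(([0], np.cumsum(counts)))
    return indptr, indices


# Toy mesh: triangles (0, 1, 2) and (1, 3, 2), vertex-to-element table in CSR form.
I = [0, 1, 3, 5, 6]
J = [0, 0, 1, 0, 1, 1]
indptr, indices = transpose_csr_pattern(I, J, n_cols=2)
for e in range(2):
    print(e, indices[indptr[e]:indptr[e + 1]])
# prints:
# 0 [0 1 2]
# 1 [1 2 3]
```

Within each element the vertex indices come out in ascending order rather than in the element's local ordering, which is why the check above compares sorted arrays; use this mapping where only the vertex set matters.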

As for `mesh.GetElementTransformation(i)`, I noticed that the wrapper calls `Tr = IsoparametricTransformation()` on every call, so a new object is allocated each time. If this allocation is the bottleneck, we could change the wrapper so that a preallocated `Tr` can be passed as a keyword argument. If not, I am not sure there is a simple way to make this faster.

tradeqvest commented 6 months ago

Thank you very much for your answer! 🙂 The first part worked really well!

Regarding `mesh.GetElementTransformation(i)`, this is the function I want to speed up:

```python
import numpy as np
from mfem.ser import GridFunction, IntegrationPoint, Vector


def interpolate_solution_at_points(
    fespace, mesh, solution, integration_points, corresponding_elements
):
    """
    Interpolate a finite element solution at given points.

    Args:
    - fespace: The finite element space (mfem.FiniteElementSpace)
    - mesh: The mesh (mfem.Mesh)
    - solution: The finite element solution (numpy array)
    - integration_points: The points where the solution is to be interpolated (numpy array of shape (n_points, dim))
    - corresponding_elements: Index of the element containing each point (sequence of length n_points)

    Returns:
    - interpolated_values: The interpolated solution values at the given points (numpy array of shape (n_points, 1))
    """
    dim = fespace.GetMesh().Dimension()
    assert (
        integration_points.shape[1] == dim
    ), "Dimension of points must match the mesh dimension"
    grid_function = GridFunction(fespace)
    grid_function.Assign(np.ravel(solution))
    n_points = integration_points.shape[0]
    interpolated_values = np.zeros(n_points)
    ip = IntegrationPoint()
    for i, elem in enumerate(corresponding_elements):
        trans = mesh.GetElementTransformation(elem)  # new transformation per point
        point = Vector(integration_points[i, :])     # new Vector per point
        trans.TransformBack(point, ip)               # physical -> reference coordinates
        interpolated_values[i] = grid_function.GetValue(elem, ip)
    return interpolated_values.reshape(-1, 1)
```

If you see a way to make it more efficient, please let me know! 🙂 Thank you in advance for your time and effort!

justinlaughlin commented 5 months ago

Hi @tradeqvest

What is the size of the problem you are working with?

The reason I ask is that if Nel << Npoints (far fewer elements than points), it may be worthwhile, as a first pass, to build the transformation for each element once and look it up in the loop, rather than recreating it for every point.
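A minimal sketch of that first suggestion, assuming `corresponding_elements` is an integer sequence as in the function above: group the point indices by element with NumPy so that `mesh.GetElementTransformation(elem)` needs to be called only once per distinct element. The helper name `group_points_by_element` is hypothetical.

```python
import numpy as np


def group_points_by_element(corresponding_elements):
    """Return {element_id: array of point indices located in that element}."""
    elems = np.asarray(corresponding_elements)
    # A stable sort keeps the points of each element in their original order.
    order = np.argsort(elems, kind="stable")
    # First position of each distinct element in the sorted sequence.
    unique_elems, starts = np.unique(elems[order], return_index=True)
    groups = np.split(order, starts[1:])
    return dict(zip(unique_elems.tolist(), groups))


groups = group_points_by_element([3, 1, 3, 0, 1])
print({e: idx.tolist() for e, idx in groups.items()})
# prints: {0: [3], 1: [1, 4], 3: [0, 2]}
```

In the interpolation loop one would then iterate over `groups.items()`, fetch the transformation once per element and reuse it for every point index in that group. This relies on the PyMFEM wrapper returning a fresh transformation object per call, as described earlier in the thread; in C++ MFEM the no-argument `GetElementTransformation(i)` returns a pointer to an object owned by the mesh, so caching those pointers would not work the same way.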

I ran a quick profile, and although `mesh.GetElementTransformation` does take some time, a large share of the time went into constructing `Vector` objects. You could construct a single `Vector` before your loop and update its values inside the loop.
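A minimal sketch of the second suggestion. The function name `interpolate_reusing_vector` is hypothetical, and the caller is assumed to preallocate `point` (e.g. `mfem.Vector(dim)`) and `ip` (e.g. `mfem.IntegrationPoint()`) once. It also assumes, as in PyMFEM, that `Vector.GetDataArray()` returns a writable NumPy view of the vector's storage, so writing into that view updates the `Vector` without allocating a new one.

```python
import numpy as np


def interpolate_reusing_vector(mesh, grid_function, points, elements, point, ip):
    """Interpolate grid_function at points, reusing one Vector and one IntegrationPoint.

    point: preallocated vector of size dim (assumed: e.g. mfem.Vector(dim)).
    ip: preallocated integration point (assumed: e.g. mfem.IntegrationPoint()).
    """
    points = np.asarray(points, dtype=float)
    buf = point.GetDataArray()  # NumPy view onto the Vector's storage (assumed writable)
    values = np.zeros(points.shape[0])
    for k, elem in enumerate(elements):
        elem = int(elem)
        trans = mesh.GetElementTransformation(elem)
        buf[:] = points[k]  # overwrite in place instead of building a new Vector
        trans.TransformBack(point, ip)  # physical -> reference coordinates, written into ip
        values[k] = grid_function.GetValue(elem, ip)
    return values.reshape(-1, 1)
```

With real objects this would be called as, for example, `interpolate_reusing_vector(mesh, grid_function, integration_points, corresponding_elements, mfem.Vector(dim), mfem.IntegrationPoint())`. It still fetches one transformation per point; combining it with grouping points by element would address both costs measured in the profile.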

I'm not aware of a vectorized solution (maybe @sshiraiwa knows of one). Could you try those two things and see whether they improve the speed?