kinnala / scikit-fem

Simple finite element assemblers
https://scikit-fem.readthedocs.io
BSD 3-Clause "New" or "Revised" License
509 stars 80 forks source link

Parallelizing the assembly over the elements #450

Closed kinnala closed 3 years ago

kinnala commented 4 years ago

I'm investigating the feasibility of parallelizing assembly over elements.

Take the following snippet (from #439 ):

import pygmsh as pm
import numpy as np
from skfem.io import from_meshio
from skfem.helpers import ddot, grad, dot, transpose, prod
import matplotlib.pyplot as plt
from math import pi
from skfem import *

def vdet(A):
    """Return the determinant of a field of 3x3 matrices.

    Parameters
    ----------
    A : ndarray
        Array indexed as ``A[i, j, ...]`` with ``i, j`` in ``range(3)``;
        any trailing axes (e.g. elements and quadrature points) are
        processed element-wise.

    Returns
    -------
    ndarray
        ``det(A[:, :, ...])`` with the trailing shape of ``A``.
    """
    # Cofactor expansion along the first row, vectorised over trailing axes.
    return (A[0, 0] * (A[1, 1] * A[2, 2] - A[1, 2] * A[2, 1])
            - A[0, 1] * (A[1, 0] * A[2, 2] - A[1, 2] * A[2, 0])
            + A[0, 2] * (A[1, 0] * A[2, 1] - A[1, 1] * A[2, 0]))

def vinv(A):
    """Invert a field of 3x3 matrices with the adjugate formula.

    Parameters
    ----------
    A : ndarray
        Array indexed as ``A[i, j, ...]``; only the leading 3x3 block of the
        first two axes is read, and trailing axes are handled element-wise.

    Returns
    -------
    invA : ndarray
        Array shaped like ``A`` with ``invA[i, j] = adj(A)[i, j] / det(A)``.
    detA : ndarray
        ``vdet(A)``, returned so callers need not recompute it.
    """
    invA = np.zeros_like(A)
    detA = vdet(A)
    a00, a01, a02 = A[0, 0], A[0, 1], A[0, 2]
    a10, a11, a12 = A[1, 0], A[1, 1], A[1, 2]
    a20, a21, a22 = A[2, 0], A[2, 1], A[2, 2]
    # Adjugate (transposed cofactor) entries, keyed by their target index.
    adjugate = {
        (0, 0): -a12 * a21 + a11 * a22,
        (1, 0): a12 * a20 - a10 * a22,
        (2, 0): -a11 * a20 + a10 * a21,
        (0, 1): a02 * a21 - a01 * a22,
        (1, 1): -a02 * a20 + a00 * a22,
        (2, 1): a01 * a20 - a00 * a21,
        (0, 2): -a02 * a11 + a01 * a12,
        (1, 2): a02 * a10 - a00 * a12,
        (2, 2): -a01 * a10 + a00 * a11,
    }
    for (row, col), entry in adjugate.items():
        invA[row, col] = entry / detA
    return invA, detA

def firstPKStress(u):
    """Return the first Piola-Kirchhoff stress of the neo-Hookean model.

    ``P = mu F - mu F^{-T} + lmbda J (J - 1) F^{-T}`` with ``F = I + grad(u)``
    and ``J = det(F)``.  ``mu`` and ``lmbda`` are read from module globals.

    Parameters
    ----------
    u : DiscreteField
        Interpolated displacement, e.g. ``w["w"]`` inside a form.

    Returns
    -------
    ndarray
        ``P`` indexed as ``P[i, j, ...]``.
    """
    # Work on a copy: grad(u) appears to return the field's own gradient
    # array, and the same field is handed to every call of the form, so an
    # in-place "+= 1" on the diagonal would accumulate the identity.
    F = grad(u).copy()
    for i in range(3):
        F[i, i] += 1.
    # vinv also returns det(F), so there is no need to call vdet separately.
    invF, J = vinv(F)
    invFT = transpose(invF)
    return mu * F - mu * invFT + lmbda * J * (J - 1) * invFT

def jacobianPK(u):
    """Return the tangent ``C = dP/dF`` of ``firstPKStress``.

    ``C[i, j, k, l] = mu d_ik d_jl
                      + (mu - lmbda J (J - 1)) Finv[j, k] Finv[l, i]
                      + lmbda (2 J - 1) J Finv[j, i] Finv[l, k]``
    with ``F = I + grad(u)``, ``J = det(F)`` and ``Finv = F^{-1}``.
    ``mu`` and ``lmbda`` are read from module globals.

    Parameters
    ----------
    u : DiscreteField
        Interpolated displacement, e.g. ``w["w"]`` inside a form.

    Returns
    -------
    ndarray
        ``C`` indexed as ``C[i, j, k, l, ...]``.
    """
    dw = grad(u)
    eye = np.zeros_like(dw)
    for i in range(3):
        eye[i, i] += 1.
    # Build F out of place so the field's gradient array is not modified;
    # the form calls this function once per pair of basis functions.
    F = dw + eye
    Finv, J = vinv(F)
    dFdF = np.einsum("ik...,jl...->ijkl...", eye, eye)
    dFinvdF = np.einsum("jk...,li...->ijkl...", Finv, Finv)
    C = (mu * dFdF + mu * dFinvdF
         - lmbda * J * (J - 1) * dFinvdF
         + lmbda * (2 * J - 1) * J
         * np.einsum("ji...,lk...->ijkl...", Finv, Finv))
    return C

mesh = MeshTet()
mesh.refine(4)
elem = ElementTetP1()
uelem = ElementVectorH1(elem)
iBasis = InteriorBasis(mesh, uelem)
fBasis = FacetBasis(mesh, uelem)
u = np.zeros(iBasis.N)  # one entry per DOF, all displacement components

# material parameters and initialisation
bodyForce = np.array([0., -1./2, 0])
E, nu = 10., 0.3
# Lame parameters computed from E and nu; the forms above read mu and lmbda
# as module-level globals.
mu = E/2/(1+nu)
lmbda = 2*mu*nu/(1-2*nu)

# DOFs on the faces x = 0 ("left") and x = 1 ("right").  np.isclose is used,
# as for the node sets below, rather than exact floating-point equality.
dofs = {
    "left": iBasis.get_dofs(lambda x: np.isclose(x[0], 0.)),
    "right": iBasis.get_dofs(lambda x: np.isclose(x[0], 1.)),
}

# assign DirichletBC
# variables used in the FEniCS demo
scale = y0 = z0 = 0.5
theta = pi/3.

# bta: scaling factor for Newton's method
bta = 0.7

# Right face: displacement that rotates the face by theta about the line
# y = y0, z = z0, scaled by `scale`; no displacement along x.
u1Right = 0.
u2Right = lambda x,y,z: scale*(y0 + (y - y0)*np.cos(theta) - (z - z0)*np.sin(theta) - y)
u3Right = lambda x,y,z: scale*(z0 + (y - y0)*np.sin(theta) + (z - z0)*np.cos(theta) - z)

rightNodes = mesh.p[:,mesh.nodes_satisfying(lambda x: np.isclose(x[0], 1.))]
leftNodes = mesh.p[:,mesh.nodes_satisfying(lambda x: np.isclose(x[0], 0.))]

# Clamp the left face; prescribe the rotation on the right face.
u[dofs["left"].nodal['u^1']] = 0.
u[dofs["left"].nodal['u^2']] = 0.
u[dofs["left"].nodal['u^3']] = 0.
u[dofs["right"].nodal['u^1']] = 0.
u[dofs["right"].nodal['u^2']] = u2Right(*iBasis.doflocs[:, dofs["right"].nodal['u^2']])
u[dofs["right"].nodal['u^3']] = u3Right(*iBasis.doflocs[:, dofs["right"].nodal['u^3']])

# DOFs not listed in `dofs` -- presumably the unknowns of the Newton solve.
I = iBasis.complement_dofs(dofs)

@LinearForm
def rhs(v, w):
    """Linear form ``P(w) : grad(v)``, the internal-force part of the residual.

    NOTE: the body-force term ``dot(bodyForce, v)`` is currently left out.
    """
    stress = firstPKStress(w["w"])
    return ddot(stress, grad(v))

@BilinearForm
def jac(u, v, w):
    """Tangent bilinear form: sum over ijkl of C[i,j,k,l] grad(u)[i,j] grad(v)[k,l].

    ``C`` is the tangent from ``jacobianPK`` evaluated at ``w["w"]``.
    """
    tangent = jacobianPK(w["w"])
    return np.einsum('ijkl...,ij...,kl...->...', tangent, grad(u), grad(v))


# Interpolate the current displacement `u`; passed to the assembler as
# ``w=w`` so that the forms receive it as ``w["w"]``.
w = iBasis.interpolate(u)

Assembly takes quite a while because so many FP operations are done inside the form:

In [3]: %time     J = asm(jac, iBasis, w=w)                                                                                                         
CPU times: user 12.4 s, sys: 4.01 s, total: 16.4 s
Wall time: 16.4 s

This is despite there being only about 4k DOFs (Edit: 4k nodes, i.e. 3 * 4k ≈ 12k DOFs with three displacement components per node)

In [12]: mesh.p.shape                                                                                                                               
Out[12]: (3, 4233)

What happens if we assemble only half of the elements?


In [6]: ib1 = InteriorBasis(mesh, uelem, elements=mesh.elements_satisfying(lambda x: x[0]<0.5))                                                     

In [7]: ib2 = InteriorBasis(mesh, uelem, elements=mesh.elements_satisfying(lambda x: x[0]>0.5)) 

In [9]: w1 = ib1.interpolate(u)                                                                                                                     

In [10]: w2 = ib2.interpolate(u)                                                                                                                    

In [11]: %time asm(jac, ib1, w=w1)                                                                                                                  
CPU times: user 5.99 s, sys: 1.92 s, total: 7.91 s
Wall time: 7.92 s
Out[11]: 
<12699x12699 sparse matrix of type '<class 'numpy.float64'>'
    with 260207 stored elements in Compressed Sparse Row format>

Looking at htop while this is running, we see that assembly (in this case) is done mostly using a single core.

So we could try assembling ib1 and ib2 in parallel and save some time.

Let's try this using dask.

In [6]: import dask.bag as db                                                                                                                       

In [7]: b = db.from_sequence([(ib1, w1), (ib2, w2)])                                                                                                

In [9]: c = b.map(lambda x: asm(jac, x[0], w=x[1])) 

In [12]: %time c.compute()                                                                                                                          
CPU times: user 83.7 ms, sys: 236 ms, total: 320 ms
Wall time: 12.6 s
Out[12]: 
[<12699x12699 sparse matrix of type '<class 'numpy.float64'>'
    with 260207 stored elements in Compressed Sparse Row format>,
 <12699x12699 sparse matrix of type '<class 'numpy.float64'>'
    with 260693 stored elements in Compressed Sparse Row format>]

12.6 s < 16.4 s

So we seem to have a chance of saving some seconds by splitting the assembly over elements.

Remaining questions:

I'll make a branch which explores this when I have time.

ahojukka5 commented 4 years ago

If you want to optimize for performance, I would suggest profiling the memory usage of a single thread assembly first. There might be some easy ways to get the code faster.

kinnala commented 4 years ago

Could be, but I'm pretty happy with the single-thread performance for my own use cases, and was only surprised in #439 by the amount of overhead generated by evaluating a complicated form. I've previously done plenty of profiling outside of the form evaluation parts of the code, and I'm quite convinced that there aren't massive gains (in terms of processor time) to be had there. This is especially true now that almost everything gets precomputed in InteriorBasis and initializing it is not a bottleneck AFAIK; asm is basically only evaluating the user-given form, using the huge number of values precomputed in InteriorBasis as parameters.

Everything that happens within a form is basically multiplication and summation of large 2D NumPy arrays, and I don't know how much we're really able to improve that because the forms are provided by the user. On the other hand, there could be plenty of improvements in terms of memory usage.

Edit: I'm happy to be proven wrong though :-)

kinnala commented 4 years ago

I mean, comparing the numbers in the README for Laplace equation and the numbers for the Jacobian for Neohookean hyperelasticity (both evaluated by my laptop), it's obvious that the overhead comes from the form, especially since same element is used in both cases.

One option would be to try to JIT compile the entire form using Numba or similar.

kinnala commented 4 years ago

I tried to JIT compile the above form using Numba:

In [3]: %time     J = asm(jac2, iBasis, w=w)                                                                                                        
CPU times: user 3.98 s, sys: 19.9 ms, total: 4 s
Wall time: 4.01 s

So it is 4x faster, but the form is 100x uglier:

@jit(nopython=True)
def numbajac(du, dv, dw):
    """Pointwise tangent-stiffness integrand, JIT-compiled with numba.

    Evaluates ``sum_{ijkl} C[i, j, k, l] * du[i, j] * dv[k, l]`` at every
    ``(a, b)``, where ``C`` is the same tangent that ``jacobianPK`` builds
    with einsum::

        C[i, j, k, l] = mu d_ik d_jl
                        + (mu - lmbda J (J - 1)) Finv[j, k] Finv[l, i]
                        + lmbda (2 J - 1) J Finv[j, i] Finv[l, k]

    with ``F = I + dw``, ``J = det(F)`` and ``Finv = F^{-1}``.  ``mu`` and
    ``lmbda`` are read from module globals.

    Parameters
    ----------
    du, dv, dw : ndarray
        Gradients indexed ``[i, j, a, b]`` of the trial function, the test
        function and the current displacement; ``i, j`` run over 3.

    Returns
    -------
    ndarray
        Integrand with shape ``du[0, 0].shape``.
    """
    out = np.zeros_like(du[0, 0])

    # F = I + dw, built in a copy.  dw may be the interpolated field's own
    # gradient array, which is passed unchanged for every (u, v) pair of
    # basis functions; adding the identity in place would accumulate it.
    F = dw.copy()
    for a in range(F.shape[2]):
        for b in range(F.shape[3]):
            F[0, 0, a, b] += 1.0
            F[1, 1, a, b] += 1.0
            F[2, 2, a, b] += 1.0

    # J = det(F) and Finv = adj(F) / J, evaluated point by point.
    J = np.zeros_like(out)
    Finv = np.zeros_like(F)
    for a in range(J.shape[0]):
        for b in range(J.shape[1]):
            f00, f01, f02 = F[0, 0, a, b], F[0, 1, a, b], F[0, 2, a, b]
            f10, f11, f12 = F[1, 0, a, b], F[1, 1, a, b], F[1, 2, a, b]
            f20, f21, f22 = F[2, 0, a, b], F[2, 1, a, b], F[2, 2, a, b]
            det = (f00 * (f11 * f22 - f12 * f21)
                   - f01 * (f10 * f22 - f12 * f20)
                   + f02 * (f10 * f21 - f11 * f20))
            J[a, b] = det
            Finv[0, 0, a, b] = (-f12 * f21 + f11 * f22) / det
            Finv[1, 0, a, b] = (f12 * f20 - f10 * f22) / det
            Finv[2, 0, a, b] = (-f11 * f20 + f10 * f21) / det
            Finv[0, 1, a, b] = (f02 * f21 - f01 * f22) / det
            Finv[1, 1, a, b] = (-f02 * f20 + f00 * f22) / det
            Finv[2, 1, a, b] = (f01 * f20 - f00 * f21) / det
            Finv[0, 2, a, b] = (-f02 * f11 + f01 * f12) / det
            Finv[1, 2, a, b] = (f02 * f10 - f00 * f12) / det
            Finv[2, 2, a, b] = (-f01 * f10 + f00 * f11) / det

    # Contract the tangent with du and dv.
    kron = np.eye(du.shape[0])
    for i in range(du.shape[0]):
        for j in range(du.shape[0]):
            for k in range(du.shape[0]):
                for l in range(du.shape[0]):
                    for a in range(J.shape[0]):
                        for b in range(J.shape[1]):
                            Jab = J[a, b]
                            C = (mu * kron[i, k] * kron[j, l]
                                 + (mu - lmbda * Jab * (Jab - 1))
                                 * Finv[j, k, a, b] * Finv[l, i, a, b]
                                 + lmbda * (2 * Jab - 1) * Jab
                                 * Finv[j, i, a, b] * Finv[l, k, a, b])
                            out[a, b] += C * du[i, j, a, b] * dv[k, l, a, b]
    return out

@BilinearForm
def jac2(u, v, w):
    """Tangent form whose integrand is evaluated by the numba kernel ``numbajac``."""
    grad_u, grad_v, grad_w = grad(u), grad(v), grad(w['w'])
    return numbajac(grad_u, grad_v, grad_w)

I suppose it will be even faster if entire asm was JIT'ed, integration and everything.

gdmcbain commented 4 years ago

@ahojukka5 is correct to point out that one should always profile before attempting to optimize, but there is another aspect to the general issue of ‘parallelizing the assembly over the elements’. Although optimization seems to have been in mind here, besides the direct use of scikit-fem as a finite element assembler, another of its primary purposes is ‘a focus on computational experimentation’ (paper.md), and much contemporary research in finite element methodology does concern parallelization, so it might be worth getting some techniques and idioms in place for parallelizing the assembly over the elements, regardless of whether it's faster or slower.

For that, I think dask.bag looks good. Very nice API and easy to install everywhere. Probably reasonably performant too?

Can we combine these resulting matrices so that it actually saves time? I suppose the correct place to do this is before a call to any scipy.sparse routines.

Actually, following on from above, I'm thinking about not combining the resulting matrices at all, but leaving them as separate operators in a domain decomposition setting. I have to attend to other matters now, but do hope to return to this, probably using dask.bag eventually.

On speeding up the assembly of hyperelastic forms #439, I'm puzzled that using NumPy to invert the local matrices was so slow. I did encounter this previously in looking at the inverse mapping of quadrilateral or hexahedral elements but didn't understand it then and was also puzzled then.

gdmcbain commented 4 years ago

Actually following on from above, I'm thinking about not combining the resulting matrices at all, but leaving them as separate operators in a domain decomposition setting.

I had been thinking about this for a while but a recent additional motivation came from Knoll & Keyes (2004, §3.2 ‘Newton–Krylov–Schwarz’):

Newton–Krylov–Schwarz (NKS) is a preconditioned Jacobian-free Newton–Krylov method in which the action of the preconditioner is composed from those of preconditioners defined on individual geometric subdomains.

bhaveshshrimali commented 4 years ago

I suppose it will be even faster if entire asm was JIT'ed, integration and everything.

@kinnala If you have time, could you expand on this? I suppose this is more involved than simply decorating with jit if I understand correctly. I could experiment with this option next week as I would have some free time.

kinnala commented 4 years ago

I actually tried this and the gain was insignificant, something like 20% max. improvement to running time. The idea would have been to inline and JIT all loops starting from this: https://github.com/kinnala/scikit-fem/blob/master/skfem/assembly/form/bilinear_form.py#L63 But it requires some major refactoring of different pieces and the gain was too small for the added complexity.

bhaveshshrimali commented 4 years ago

I see. Thanks a lot!! It does seem that JIT-ing the form is pretty much it as far as speeding up the code, right?

kinnala commented 4 years ago

Yes I think so, unless you want to try what's suggested in the title of the issue and the first post, i.e. parallelize over elements. That should end up being multiple times faster if done properly.

bhaveshshrimali commented 4 years ago

I see. I will try experimenting with dask and explore how the individual matrices could be combined, and so on.

kinnala commented 4 years ago

My guess is that if parallelization over elements is to be performed, combining the results should be done before a call to _assemble_scipy_matrix here, e.g., by initializing data, rows and cols so that they can hold the entire matrix and then doing the loops before _assemble_scipy_matrix only for a subset of elements per thread.

adtzlr commented 3 years ago

I found out that with joblib.Parallel the assemble function can be executed in parallel with few modifications to the code. I used this technique in my own private FE code (though a little bit cleaner) and it works quite well there. The only drawback is that we can't use the default loky backend, but with threading everything is OK.

Just modify the file skfem/assembly/form/bilinear_form.py and replace

        # loop over the indices of local stiffness matrix
        for j in range(ubasis.Nbfun):
            for i in range(vbasis.Nbfun):
                ixs = slice(nt * (vbasis.Nbfun * j + i),
                            nt * (vbasis.Nbfun * j + i + 1))
                rows[ixs] = vbasis.element_dofs[i]
                cols[ixs] = ubasis.element_dofs[j]
                data[ixs] = self._kernel(
                    ubasis.basis[j],
                    vbasis.basis[i],
                    wdict,
                    dx,
                )

with

        from joblib import Parallel, delayed

        ixs_list = []
        bij_list = []

        # loop over the indices of local stiffness matrix (pre-loop)
        for j in range(ubasis.Nbfun):
            for i in range(vbasis.Nbfun):

                bij_list.append([ubasis.basis[j], vbasis.basis[i]])

                ixs = slice(nt * (vbasis.Nbfun * j + i),
                            nt * (vbasis.Nbfun * j + i + 1))

                ixs_list.append(ixs)

                rows[ixs] = vbasis.element_dofs[i]
                cols[ixs] = ubasis.element_dofs[j]

        data_list = Parallel(n_jobs=-1, prefer="threads")(
            delayed(self._kernel)(*bij, wdict, dx) for bij in bij_list)

        for ixs, d in zip(ixs_list, data_list):
            data[ixs] = d

I get about a 2x speed-up of my three-field hyperelasticity example from https://github.com/kinnala/scikit-fem/issues/616 in combination with a 3x refined hex-mesh.

What do you think about that?

kinnala commented 3 years ago

Cool, I think we used to have something like that but the speed improvement was visible only if GIL was circumvented, e.g., using numba nogil flag. I'll try it out with different elements and problem sizes at some point and report back.

kinnala commented 3 years ago

Here it is https://github.com/kinnala/scikit-fem/blob/666a259a434e07f0852a10483bf3a1a913ff1b89/skfem/assembly/form/bilinear_form.py#L76

kinnala commented 3 years ago

Is it the same thing here? https://joblib.readthedocs.io/en/latest/parallel.html#thread-based-parallelism-vs-process-based-parallelism

kinnala commented 3 years ago

We originally removed the implementation using threading because any error messages inside the forms were really complicated to understand due to threading adding lots of noise.

Edit: But I think we could add now an optional flag

@BilinearForm(parallel_kernel=True)
def bilinf(u, v, w):
    pass

to enable this kind of behaviour.

adtzlr commented 3 years ago

Yes, as far as I know joblib threading is the same as your original threading implementation. Optional flag sounds good! Errors could be checked without parallelization and then the flag could be turned on. I played with joblib some time ago and the best speed up can be achieved only by using the default backend. But things are getting complicated if the code uses custom classes (pickle errors)...

kinnala commented 3 years ago

Yes, as far as I know joblib threading is the same as your original threading implementation. Optional flag sounds good! Errors could be checked without parallelization and then the flag could be turned on. I played with joblib some time ago and the best speed up can be achieved only by using the default backend. But things are getting complicated if the code uses custom classes (pickle errors)...

Started this work in #625, check it out if you have time.

bhaveshshrimali commented 3 years ago

Thanks @kinnala @adtzlr for the work. Just happen to have gotten back to this.

For the example in #439, this is what I get for the assembly times on my laptop (haven't done a detailed timing analysis). A ~2x speed up is indeed nice to have.

import numpy as np
from skfem.helpers import grad, dot
from numba import jit
from skfem import *

@jit(nopython=True, nogil=True)
def vdet(F):
    """Determinant of a field of 3x3 matrices, JIT-compiled with numba.

    ``F`` is indexed ``F[i, j, a, b]``; the result has shape ``F[0, 0].shape``
    and holds ``det(F[:, :, a, b])`` at ``[a, b]``.  ``nogil=True`` lets
    several threads run the kernel at the same time.
    """
    det = np.zeros_like(F[0, 0])
    for a in range(det.shape[0]):
        for b in range(det.shape[1]):
            f00, f01, f02 = F[0, 0, a, b], F[0, 1, a, b], F[0, 2, a, b]
            f10, f11, f12 = F[1, 0, a, b], F[1, 1, a, b], F[1, 2, a, b]
            f20, f21, f22 = F[2, 0, a, b], F[2, 1, a, b], F[2, 2, a, b]
            # cofactor expansion along the first row
            det[a, b] += (f00 * (f11 * f22 - f12 * f21)
                          - f01 * (f10 * f22 - f12 * f20)
                          + f02 * (f10 * f21 - f11 * f20))
    return det

@jit(nopython=True, nogil=True)
def vinv(F):
    """Inverse of a field of 3x3 matrices via the adjugate, JIT-compiled.

    ``F`` is indexed ``F[i, j, a, b]``.  The result has the same shape, and
    ``Finv[:, :, a, b]`` is ``adj(F[:, :, a, b]) / det(F[:, :, a, b])``.
    No check for singular matrices is made.
    """
    J = vdet(F)
    Finv = np.zeros_like(F)
    for a in range(J.shape[0]):
        for b in range(J.shape[1]):
            f00, f01, f02 = F[0, 0, a, b], F[0, 1, a, b], F[0, 2, a, b]
            f10, f11, f12 = F[1, 0, a, b], F[1, 1, a, b], F[1, 2, a, b]
            f20, f21, f22 = F[2, 0, a, b], F[2, 1, a, b], F[2, 2, a, b]
            det = J[a, b]
            Finv[0, 0, a, b] += (-f12 * f21 + f11 * f22) / det
            Finv[1, 0, a, b] += (f12 * f20 - f10 * f22) / det
            Finv[2, 0, a, b] += (-f11 * f20 + f10 * f21) / det
            Finv[0, 1, a, b] += (f02 * f21 - f01 * f22) / det
            Finv[1, 1, a, b] += (-f02 * f20 + f00 * f22) / det
            Finv[2, 1, a, b] += (f01 * f20 - f00 * f21) / det
            Finv[0, 2, a, b] += (-f02 * f11 + f01 * f12) / det
            Finv[1, 2, a, b] += (f02 * f10 - f00 * f12) / det
            Finv[2, 2, a, b] += (-f01 * f10 + f00 * f11) / det
    return Finv

@jit(nopython=True, nogil=True)
def numbares(dv, dw):
    """Residual integrand ``P(F) : dv`` of the neo-Hookean model, JIT-compiled.

    ``P = mu F - mu F^{-T} + lmbda J (J - 1) F^{-T}`` with ``F = I + dw`` and
    ``J = det(F)``; ``mu`` and ``lmbda`` are read from module globals.
    ``dv`` and ``dw`` are gradients indexed ``[i, j, a, b]``.  ``dw`` itself
    is left unmodified.
    """
    # F = I + dw, assembled in a fresh array.
    F = np.zeros_like(dw)
    for a in range(dw.shape[2]):
        for b in range(dw.shape[3]):
            for d in range(3):
                F[d, d, a, b] += 1.0
    F += dw
    J = vdet(F)
    Finv = vinv(F)

    out = np.zeros_like(dv[0, 0])
    dim = dv.shape[0]
    for i in range(dim):
        for j in range(dim):
            for a in range(J.shape[0]):
                for b in range(J.shape[1]):
                    Jab = J[a, b]
                    stress = (mu * F[i, j, a, b]
                              - mu * Finv[j, i, a, b]
                              + lmbda * Jab * (Jab - 1.0) * Finv[j, i, a, b])
                    out[a, b] += stress * dv[i, j, a, b]
    return out

@jit(nopython=True, nogil=True)
def numbajac(du, dv, dw):
    """Tangent-stiffness integrand ``sum C[i,j,k,l] du[i,j] dv[k,l]``, JIT-compiled.

    ``C[i, j, k, l] = mu d_ik d_jl
                      + (mu - lmbda J (J - 1)) Finv[j, k] Finv[l, i]
                      + lmbda (2 J - 1) J Finv[j, i] Finv[l, k]``
    with ``F = I + dw``.  All arrays are gradients indexed ``[i, j, a, b]``;
    ``mu`` and ``lmbda`` are read from module globals, and ``dw`` is left
    unmodified.
    """
    # F = I + dw, assembled in a fresh array.
    F = np.zeros_like(dw)
    for a in range(dw.shape[2]):
        for b in range(dw.shape[3]):
            for d in range(3):
                F[d, d, a, b] += 1.0
    F += dw
    J = vdet(F)
    Finv = vinv(F)

    kron = np.eye(dw.shape[0])
    out = np.zeros_like(du[0, 0])
    dim = du.shape[0]
    for i in range(dim):
        for j in range(dim):
            for k in range(dim):
                for l in range(dim):
                    # constant part mu * d_ik * d_jl, independent of (a, b)
                    delta = mu * kron[i, k] * kron[j, l]
                    for a in range(J.shape[0]):
                        for b in range(J.shape[1]):
                            Jab = J[a, b]
                            C_ijkl = (delta
                                      + (mu - lmbda * Jab * (Jab - 1))
                                      * Finv[j, k, a, b]
                                      * Finv[l, i, a, b]
                                      + lmbda
                                      * (2 * Jab - 1)
                                      * Jab
                                      * Finv[j, i, a, b]
                                      * Finv[l, k, a, b])
                            out[a, b] += C_ijkl * du[i, j, a, b] * dv[k, l, a, b]
    return out

@jit(nopython=True, nogil=True)
def numbaEnergy(dw):
    """Strain-energy density of the neo-Hookean model, JIT-compiled.

    ``W = mu/2 (I1 - 3) - mu ln J + lmbda/2 (J - 1)^2`` with ``F = I + dw``,
    ``I1 = sum_ij F[i, j]^2`` and ``J = det(F)``.  ``dw`` is a displacement
    gradient indexed ``[i, j, a, b]``; ``mu`` and ``lmbda`` are module globals.
    """
    # F = I + dw, assembled in a fresh array.
    F = np.zeros_like(dw)
    for a in range(dw.shape[2]):
        for b in range(dw.shape[3]):
            for d in range(3):
                F[d, d, a, b] += 1.0
    F += dw
    J = vdet(F)

    out = np.zeros_like(dw[0, 0])
    I1 = np.zeros_like(dw[0, 0])
    for a in range(J.shape[0]):
        for b in range(J.shape[1]):
            # first invariant: sum of squared entries of F, row by row
            for i in range(dw.shape[0]):
                for j in range(dw.shape[1]):
                    I1[a, b] += F[i, j, a, b] * F[i, j, a, b]
            Jab = J[a, b]
            out[a, b] = (mu / 2.0 * (I1[a, b] - 3.0)
                         - mu * np.log(Jab)
                         + lmbda / 2.0 * (Jab - 1.0) ** 2)
    return out

mesh = MeshTet().refined(4)
elem = ElementTetP1()
uelem = ElementVectorH1(elem)
iBasis = InteriorBasis(mesh, uelem, intorder=2)
fBasis = FacetBasis(mesh, uelem, intorder=2)
u = np.zeros(iBasis.N)

# Lame parameters from E and nu.  The numba kernels read mu and lmbda as
# globals; numba typically freezes such globals at first compilation, so
# set them before the kernels are first called.
E, nu = 10.0, 0.3
mu = E / 2 / (1 + nu)
lmbda = 2 * mu * nu / (1 - 2 * nu)

# Body force used by ``rhs``; it was never defined in this snippet, so
# assembling ``rhs`` raised NameError.  Same value as the first snippet.
bodyForce = np.array([0.0, -1.0 / 2, 0.0])

@LinearForm(nthreads=17)
def rhs(v, w):
    """Load form for a constant body force: ``-bodyForce . v``."""
    load = dot(bodyForce, v)
    return -load

@LinearForm(nthreads=17)
def rhsSurf(v, w):
    """Surface-traction load form: ``-(surfaceTraction . v) * restBoundary(x)``."""
    # NOTE(review): surfaceTraction and restBoundary are not defined in this
    # snippet; they must exist at module level before this form is assembled.
    # restBoundary(w.x) presumably acts as an indicator of the loaded facets.
    traction_work = dot(surfaceTraction, v)
    return -traction_work * restBoundary(w.x)

@BilinearForm(nthreads=17)
def jac2(u, v, w):
    """Tangent form whose integrand is evaluated by the numba kernel ``numbajac``."""
    grad_u, grad_v, grad_w = grad(u), grad(v), grad(w["w"])
    return numbajac(grad_u, grad_v, grad_w)

@LinearForm(nthreads=17)
def res2(v, w):
    """Residual form whose integrand is evaluated by the numba kernel ``numbares``."""
    grad_v, grad_w = grad(v), grad(w["w"])
    return numbares(grad_v, grad_w)

@Functional(nthreads=17)
def energy(w):
    """Strain-energy functional; the density is evaluated by ``numbaEnergy``."""
    displacement_gradient = grad(w["w"])
    return numbaEnergy(displacement_gradient)


# Interpolate the (zero) displacement and assemble the tangent once;
# presumably this is the call whose timings are reported below.
w = iBasis.interpolate(u)
jac2.assemble(iBasis, w=w)

Timings

6.78 s ± 151 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) # nthreads = 0
3.49 s ± 305 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) # nthreads = 3
3.03 s ± 197 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) # nthreads = 9
3.18 s ± 283 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) # nthreads  = 17