tianjuxue / jax-am

Additive manufacturing simulation with JAX.
https://jax-am.readthedocs.io/en/latest/
GNU General Public License v3.0
265 stars, 56 forks

Error in results for spatially varying material property #23

Open · abhiawasthi1993 opened this issue 1 year ago

abhiawasthi1993 commented 1 year ago

Hello, I am attempting to solve a 2D plane-strain boundary value problem for a linear elastic material whose Young's modulus varies in space, E = E(x,y), given at the nodes. I use 'QUAD4' as the element type and assume that the Young's modulus is interpolated with the same basis functions as the solution. Below is the code I use to interpolate the nodal values to the quadrature points of each element:

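import jax
import jax.numpy as np
from jax_am.fem.core import FEM

class Elasticity(FEM):
    def custom_init(self):
        """ Override base class method
        """
        # One design variable (Young's modulus) per mesh node.
        self.flex_inds = np.arange(len(self.points))

    def get_tensor_map(self):
        def stress(u_grad, theta):
            # Plane-strain linear elasticity; theta holds E at the current quadrature point.
            nu = 0.3
            E = theta[0]
            epsilon = 0.5*(u_grad + u_grad.T)
            eps11 = epsilon[0,0]
            eps22 = epsilon[1,1]
            eps12 = epsilon[0,1]
            sig11 = E/((1 + nu)*(1 - 2*nu))*((1-nu)*eps11 + nu*eps22)
            sig22 = E/((1 + nu)*(1 - 2*nu))*(nu*eps11 + (1-nu)*eps22)
            # Note: eps12 is the tensorial shear strain; 2*mu*eps12 = E/(1 + nu)*eps12, twice this value.
            sig12 = E/((1 + nu)*2)*eps12
            sigma = np.array([[sig11, sig12],[sig12, sig22]])
            return sigma
        return stress

    def set_params(self, params):
        # params: (num_nodes, 1) nodal values of E; cells: (num_cells, num_nodes_per_cell).
        NCA = np.array(self.cells)

        def compute_params(e):
            # Interpolate the nodal values of cell e to its quadrature points:
            # shape_vals (num_quads, num_nodes_per_cell) @ Ee (num_nodes_per_cell, 1) -> (num_quads, 1)
            a = NCA[e,:]
            Ee = params[a]
            E = np.matmul(self.shape_vals, Ee)
            return np.transpose(E)

        # vmap over cells gives (num_cells, 1, num_quads); reorder to (num_cells, num_quads, 1).
        result = jax.vmap(compute_params)(np.arange(self.num_cells))
        result = np.transpose(result, axes=(0,2,1))

        self.full_params = params
        self.internal_vars['laplace'] = [result]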

When I used a uniform material parameter distribution (param = 0.01*np.ones((len(problem.flex_inds), 1))), the results matched those of a benchmark study. However, with a spatially varying (heterogeneous) distribution of the material parameter, the results deviated from the benchmark.

I looked for solutions in the existing examples, but most of them assume that the varying field is constant within each element, which is not the case here.

I would appreciate any guidance on where my handling of the spatially varying material parameter might be wrong and on how to improve the accuracy of the results.

tianjuxue commented 1 year ago

In JAX-FEM, parameters can vary down to the level of individual quadrature points. Is this function helpful for your case?

https://github.com/tianjuxue/jax-am/blob/55115e49b4354ea220296f5f5d6f906854e39a7c/jax_am/fem/core.py#L712-L731
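For reference, interpolating nodal values to quadrature points means θ_q = Σ_a N_a(ξ_q) θ_a within each cell, which is what compute_params in the first post does with vmap and matmul. Below is a minimal NumPy sketch of the same operation as a single einsum. It assumes, as that code implies, that shape_vals has shape (num_quads, nodes_per_cell); the helper name nodal_to_quad and the two-cell mesh are hypothetical. jax.numpy.einsum accepts the same subscripts, so the body should carry over to a JAX set_params.

```python
import numpy as onp


def nodal_to_quad(params, cells, shape_vals):
    """Interpolate nodal values to quadrature points: theta_q = sum_a N_a(xi_q) * theta_a.

    params:     (num_nodes, vec)              nodal values
    cells:      (num_cells, nodes_per_cell)   connectivity
    shape_vals: (num_quads, nodes_per_cell)   N_a evaluated at each quadrature point
    returns:    (num_cells, num_quads, vec)
    """
    cell_params = params[cells]  # gather per cell: (num_cells, nodes_per_cell, vec)
    return onp.einsum("qa,cav->cqv", shape_vals, cell_params)


# Bilinear (QUAD4) shape functions at the 2x2 Gauss points of the reference square.
g = 1.0 / onp.sqrt(3.0)
ref_nodes = onp.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]])
quad_pts = onp.array([[-g, -g], [g, -g], [g, g], [-g, g]])
shape_vals = (0.25
              * (1 + quad_pts[:, None, 0] * ref_nodes[None, :, 0])
              * (1 + quad_pts[:, None, 1] * ref_nodes[None, :, 1]))  # (4 quads, 4 nodes)

# Two hypothetical unit cells sharing an edge, corners numbered counter-clockwise.
points = onp.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0],
                    [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]])
cells = onp.array([[0, 1, 4, 3], [1, 2, 5, 4]])
theta = (1.0 + 2.0 * points[:, 0] + 3.0 * points[:, 1])[:, None]  # nodal field, (6, 1)

theta_q = nodal_to_quad(theta, cells, shape_vals)  # (2, 4, 1)
xy_q = nodal_to_quad(points, cells, shape_vals)     # quadrature-point coordinates, (2, 4, 2)
```

Because the shape functions sum to one at every quadrature point, a field that is linear in x and y is reproduced exactly: theta_q[..., 0] equals 1 + 2x + 3y evaluated at xy_q.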

abhiawasthi1993 commented 1 year ago

Hello, thank you for recommending this function. It does perform the desired task; however, I noticed that the final displacement field differs from the benchmark results. For example, the maximum displacement in the y-direction is 0.447056 mm with JAX, while it is 0.447335 mm from the benchmark.

I have also cross-verified the cell values returned by the function, and they are identical to the benchmark values.

Below is an MWE (minimal working example). It would be great if you could review it and share any additional insights.

import numpy as onp
import jax
import jax.numpy as np
import os
import glob

from jax_am.fem.core import FEM
from jax_am.fem.solver import ad_wrapper
from jax_am.fem.utils import save_sol
from jax_am.fem.generate_mesh import Mesh, get_meshio_cell_type
from jax_am.fem.common import rectangle_mesh

class Elasticity(FEM):
    def custom_init(self):
        self.flex_inds = np.arange(len(self.points))

    def get_tensor_map(self):
        def stress(u_grad, theta):
            nu = 0.3
            E = theta[0]
            epsilon = 0.5*(u_grad + u_grad.T)
            eps11 = epsilon[0,0]
            eps22 = epsilon[1,1]
            eps12 = epsilon[0,1]
            sig11 = E/((1 + nu)*(1 - 2*nu))*((1-nu)*eps11 + nu*eps22)
            sig22 = E/((1 + nu)*(1 - 2*nu))*(nu*eps11 + (1-nu)*eps22) 
            # Note: eps12 is the tensorial shear strain; 2*mu*eps12 = E/(1 + nu)*eps12, twice this value.
            sig12 = E/((1 + nu)*2)*eps12
            sigma = np.array([[sig11, sig12],[sig12, sig22]])
            return sigma
        return stress

    def set_params(self, params):
        result = self.convert_from_dof_to_quad(params)
        jax.debug.print("N: {}",self.shape_vals)
        self.full_params = params
        self.internal_vars['laplace'] = [result]

# SET DIRECTORY
data_dir = os.path.join(os.path.dirname(__file__), 'data')
files = glob.glob(os.path.join(data_dir, 'vtk/inverse/*'))
for f in files:
    os.remove(f)

# MESH INFORMATION
ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)
Lx, Ly = 120., 120.
meshio_mesh = rectangle_mesh(Nx=50, Ny=50, domain_x=Lx, domain_y=Ly)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type])

# BOUNDARY CONDITIONS
def fixed_xlocation(point):
    return np.isclose(point[0], 0., atol=1e-5)

def fixed_ylocation(point):
    return np.logical_and(np.isclose(point[0], 0, atol=1e-5), np.isclose(point[1], Ly, atol=1e-5))

def dbc_location(point):
    return np.isclose(point[0], Lx, atol=1e-5)

def fixed_val(point):
    return 0.

def dbc_val(point):
    return 1

dirichlet_bc_info = [[fixed_xlocation, fixed_ylocation, dbc_location],[0, 1, 0],[fixed_val, fixed_val, dbc_val]]

# PROBLEM DEFINITION
problem = Elasticity(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info)

fwd_pred = ad_wrapper(problem, linear=True, use_petsc=False)
def get_Evec(mesh):
    X = mesh.points[:,0]
    Y = mesh.points[:,1]

    t2 = 5 * np.exp(-(X-60)**2/(20**2) - (Y-60)**2/(20**2))
    E = 0.001*(10+t2)
    return E

Evec = get_Evec(mesh)
Evec = Evec[:,None]

sol = fwd_pred(Evec)

vtu_path = os.path.join(data_dir, 'vtk/sol_true.vtu')
save_sol(problem, np.hstack((sol, np.zeros((len(sol), 1)))), vtu_path, point_infos=[('theta', problem.full_params[:, 0])])
tianjuxue commented 1 year ago

What benchmark FEM software did you use? I think this is most likely due to different choices of linear solver and the associated tolerance settings. For example, on my machine I get a displacement of 0.44705704246529027 with the JAX solver and 0.44626012034961243 with the PETSc solver (use_petsc=True). They already differ!
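To see how the stopping tolerance of an iterative solver alone shifts a reported displacement, here is a minimal sketch. It is a plain NumPy conjugate gradient on a hypothetical 1D analogue of the MWE (a clamped bar with the same Gaussian modulus profile as get_Evec and a unit tip load), not the solver used by JAX-FEM, PETSc, or the MATLAB code.

```python
import numpy as onp


def conjugate_gradient(A, b, rtol, max_iter=10_000):
    """Plain conjugate gradient for a symmetric positive-definite A.

    Stops once the recursively updated residual satisfies ||r|| <= rtol * ||b||.
    """
    x = onp.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    stop = (rtol * onp.linalg.norm(b)) ** 2
    for _ in range(max_iter):
        if rs <= stop:
            break
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


# Hypothetical 1D analogue: bar of length 120 clamped at x = 0, 50 linear elements,
# modulus profile taken from get_Evec along one line, unit load at the free end.
n, length = 50, 120.0
h = length / n
x_mid = (onp.arange(n) + 0.5) * h
E = 0.001 * (10.0 + 5.0 * onp.exp(-((x_mid - 60.0) ** 2) / 20.0 ** 2))

K = onp.zeros((n + 1, n + 1))
for e in range(n):
    K[e:e + 2, e:e + 2] += E[e] / h * onp.array([[1.0, -1.0], [-1.0, 1.0]])
K = K[1:, 1:]  # drop the clamped node
f = onp.zeros(n)
f[-1] = 1.0

u_direct = onp.linalg.solve(K, f)
u_loose = conjugate_gradient(K, f, rtol=1e-3)
u_tight = conjugate_gradient(K, f, rtol=1e-10)
for name, u in [("direct", u_direct), ("CG, rtol=1e-3", u_loose), ("CG, rtol=1e-10", u_tight)]:
    print(f"{name:>15}: max displacement = {u.max():.10f}")
```

Each run satisfies its own stopping test, yet the loosely converged one can report a noticeably different maximum; tightening the tolerance brings the iterative result toward the direct solve.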

abhiawasthi1993 commented 1 year ago

I am using a custom-written MATLAB FEM solver as the benchmark. You're right, the difference in results might be due to the type of solver I have used and the associated tolerance settings.

However, when I use a homogeneous property distribution, the results match up to five significant digits. Any thoughts?

tianjuxue commented 1 year ago

With a homogeneous property distribution, the final linear system is better conditioned, so it is easier to solve and different solvers agree to 5 digits. For inhomogeneous problems the matrix is worse conditioned, which is harder on every solver, so different solvers reach different levels of accuracy.
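A small illustration of the conditioning argument, as a sketch only: a hypothetical 1D bar clamped at one end, assembled once with a uniform modulus and once with a strongly heterogeneous one, compared with numpy.linalg.cond. The contrast of 1000 is chosen to make the effect obvious and is not taken from the MWE.

```python
import numpy as onp


def bar_stiffness(E, h=1.0):
    """Stiffness matrix of a 1D bar with one linear element per entry of E.

    Node 0 is clamped (Dirichlet), so its row and column are removed.
    """
    n = len(E)
    K = onp.zeros((n + 1, n + 1))
    ke = onp.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    for e, E_e in enumerate(E):
        K[e:e + 2, e:e + 2] += E_e * ke
    return K[1:, 1:]


cond_hom = onp.linalg.cond(bar_stiffness(onp.ones(4)))
cond_het = onp.linalg.cond(bar_stiffness(onp.array([1.0, 1000.0, 1000.0, 1000.0])))
print(f"homogeneous: {cond_hom:.1f}, heterogeneous: {cond_het:.1f}")
```

Here the homogeneous matrix has a condition number of about 29, while the high-contrast one is worse by more than two orders of magnitude. How much this matters depends on the contrast: since the assembled stiffness is monotone in E, the condition number can grow by at most a factor of E_max/E_min relative to the homogeneous case on the same mesh.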

abhiawasthi1993 commented 1 year ago

That makes sense. Thank you.

Another question: if I need to provide Hessian information to the optimizer, how do I create a wrapper for that? I tried the following:

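import jax

def J_total(params):
    """ J(u(theta), theta) """
    sol = fwd_pred(params)        # forward solve via ad_wrapper, as in the MWE above
    dofs = sol.reshape(-1)
    obj_val = J_fn(dofs, params)  # J_fn (objective) and x0 (initial guess) are not shown here
    return obj_val

def hessian(f):
    # Forward-over-reverse Hessian (the same composition jax.hessian uses)
    return jax.jacfwd(jax.grad(f))

H = hessian(J_total)(x0)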

but got the following error: "can't apply forward-mode autodiff (jvp) to a custom_vjp function." I got a similar error with jax.jacrev.
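The error indicates that the function returned by ad_wrapper is a jax.custom_vjp function. Such a function carries a hand-written reverse-mode rule only, while jax.jacfwd (and jax.hessian, which is jacfwd over jacrev) needs forward mode. The toy below reproduces this with a hypothetical objective standing in for J_total. As a stopgap until proper higher-order support is available, and not as JAX-FEM's method, it approximates the Hessian by central differences of the reverse-mode gradient. That costs two gradient evaluations (two forward-plus-adjoint solves in the FEM setting) per design variable, so it is only practical for small parameter counts.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision keeps finite differences accurate


@jax.custom_vjp
def objective(x):
    # Hypothetical stand-in for J_total: like a solver-backed objective, it has only a reverse-mode rule.
    return jnp.sum(x * jnp.sin(x))


def objective_fwd(x):
    return objective(x), x  # save x as the residual for the backward pass


def objective_bwd(x, g):
    return (g * (jnp.sin(x) + x * jnp.cos(x)),)


objective.defvjp(objective_fwd, objective_bwd)

x0 = jnp.array([0.3, 1.1, -0.7])
grad_fn = jax.grad(objective)  # reverse mode: supported
g0 = grad_fn(x0)

# Forward mode (jvp, and hence jax.jacfwd / jax.hessian) is not defined for custom_vjp functions.
try:
    jax.jvp(objective, (x0,), (jnp.ones_like(x0),))
    jvp_failed = False
except Exception as err:  # the error quoted in the post above
    jvp_failed = True
    print(type(err).__name__, err)


def fd_hessian(grad_fn, x, eps=1e-5):
    """Approximate the Hessian by central differences of the reverse-mode gradient."""
    cols = []
    for i in range(x.shape[0]):
        e_i = jnp.zeros_like(x).at[i].set(eps)
        cols.append((grad_fn(x + e_i) - grad_fn(x - e_i)) / (2.0 * eps))
    H = jnp.stack(cols, axis=1)
    return 0.5 * (H + H.T)  # symmetrize away round-off asymmetry


H = fd_hessian(grad_fn, x0)
```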

tianjuxue commented 1 year ago

Please allow us some time to implement this feature. @SNMS95 previously worked on this, and there is an experimental version you can try:

https://github.com/tianjuxue/jax-am/blob/814384eb522b839d1537a168e2cc07b34553c012/jax_am/fem/autodiff_utils.py#L38

It should be able to provide higher-order derivatives.

abhiawasthi1993 commented 1 year ago

Hi @tianjuxue, I have been working on a 2D elasticity problem under harmonic loading with spatially varying material parameters. I was able to assemble the stiffness and mass matrices. The mass matrix is correct, but the stiffness matrix entries are smaller than expected when compared with an in-house MATLAB FEM solver.
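For context, under undamped steady-state harmonic loading f·e^{iωt}, the assembled matrices enter through (K − ω²M)u = f, so an error in K shifts every response amplitude even when M is right. A minimal 1D sketch, assuming linear bar elements, a consistent mass matrix and a hypothetical element-wise modulus field; at ω = 0 it must reduce to the static solve, which gives a simple check on K.

```python
import numpy as onp


def assemble_bar(E, rho, h):
    """1D bar with linear elements, one entry of E and rho per element; node 0 clamped."""
    n = len(E)
    K = onp.zeros((n + 1, n + 1))
    M = onp.zeros((n + 1, n + 1))
    ke = onp.array([[1.0, -1.0], [-1.0, 1.0]]) / h
    me = h / 6.0 * onp.array([[2.0, 1.0], [1.0, 2.0]])  # consistent element mass matrix
    for e in range(n):
        K[e:e + 2, e:e + 2] += E[e] * ke
        M[e:e + 2, e:e + 2] += rho[e] * me
    return K[1:, 1:], M[1:, 1:]


def harmonic_response(K, M, f, omega):
    # Undamped steady-state amplitude: (K - omega^2 M) u = f
    return onp.linalg.solve(K - omega ** 2 * M, f)


n, length = 10, 1.0
h = length / n
x_mid = (onp.arange(n) + 0.5) * h
E = 1.0 + 0.5 * onp.exp(-((x_mid - 0.5) / 0.2) ** 2)  # hypothetical smooth modulus field
rho = onp.ones(n)
K, M = assemble_bar(E, rho, h)

f = onp.zeros(n)
f[-1] = 1.0  # unit force amplitude at the free end
u_static = harmonic_response(K, M, f, omega=0.0)
u_dyn = harmonic_response(K, M, f, omega=1.0)  # below the first natural frequency here
```

At ω = 0 the tip amplitude equals the series compliance Σ h/E_e; below the first resonance the driving-point amplitude grows with ω.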

The shape function values and their gradients, JxW, etc. are consistent with the MATLAB-based solver.
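When comparing these quantities across codes, a self-contained textbook computation of the QUAD4 values can help isolate mismatches. This is a NumPy sketch, not JAX-FEM's implementation; the counter-clockwise corner ordering and the quadrature-point ordering below are assumptions and may differ from JAX-FEM's, so match orderings before comparing entries.

```python
import numpy as onp

# 2x2 Gauss-Legendre rule on the reference square [-1, 1]^2; all four weights equal 1.
g = 1.0 / onp.sqrt(3.0)
quad_pts = onp.array([[-g, -g], [g, -g], [g, g], [-g, g]])
quad_wts = onp.ones(4)
# Reference corner ordering (counter-clockwise).
ref_nodes = onp.array([[-1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [-1.0, 1.0]])


def quad4_shape(xi, eta):
    """Bilinear shape functions N_a and their reference gradients dN_a/d(xi, eta)."""
    N = 0.25 * (1 + ref_nodes[:, 0] * xi) * (1 + ref_nodes[:, 1] * eta)
    dN_ref = 0.25 * onp.stack([ref_nodes[:, 0] * (1 + ref_nodes[:, 1] * eta),
                               ref_nodes[:, 1] * (1 + ref_nodes[:, 0] * xi)], axis=1)
    return N, dN_ref  # shapes (4,), (4, 2)


def element_quantities(coords):
    """coords: (4, 2) physical corner coordinates, ordered like ref_nodes.

    Returns shape values (num_quads, 4), physical gradients (num_quads, 4, 2), JxW (num_quads,).
    """
    vals, grads, jxw = [], [], []
    for (xi, eta), w in zip(quad_pts, quad_wts):
        N, dN_ref = quad4_shape(xi, eta)
        J = coords.T @ dN_ref                      # J[i, j] = dx_i / dxi_j
        grads.append(dN_ref @ onp.linalg.inv(J))   # dN/dx = dN/dxi @ J^{-1}
        vals.append(N)
        jxw.append(onp.linalg.det(J) * w)
    return onp.array(vals), onp.array(grads), onp.array(jxw)


# One element of the 50 x 50 mesh on the 120 x 120 domain (side 2.4).
coords = 2.4 * onp.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
shape_vals, shape_grads, JxW = element_quantities(coords)
```

For these 2.4 × 2.4 elements each JxW is 1.2² = 1.44, and the four values sum to the element area.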

I also checked the stiffness matrix with homogeneous material properties, and the discrepancy persists. I have attached a screenshot of the comparison.

[Screenshot: stiffness matrix comparison]

tianjuxue commented 1 year ago

JAX-FEM is well tested and has been compared against other software, e.g., FEniCS, for stiffness matrix values, particularly for linear elastic problems. In a comparison with Abaqus we did find differences in the stiffness matrix, but the final solutions were the same. I am pinging @jiachenguoNU here, who has experience with this issue.

abhiawasthi1993 commented 1 year ago

Thank you, @tianjuxue, for the clarification. I will double-check my stiffness matrix and compare it with FEniCS as well.

@jiachenguoNU: Any leads on this would be very helpful.

abhiawasthi1993 commented 11 months ago

Hi @tianjuxue, I was using an incorrect constitutive model, which caused the difference in the stiffness matrix. The values now match both the benchmark and FEniCS.
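The thread does not say which part of the constitutive model was wrong. One frequent source of a factor-of-two error in 2D elasticity stiffness matrices is mixing the tensorial shear strain ε12 with the engineering shear strain γ12 = 2ε12: the Voigt matrix D multiplies γ12, while a tensor-form stress function works with ε12 directly, so the shear stress must be E/(1+ν)·ε12 = 2μ ε12. The stress functions earlier in this thread compute sig12 = E/((1 + nu)*2)*eps12 from the tensorial eps12, which is μ ε12. A NumPy sketch, with values chosen for illustration, checks the two forms against each other.

```python
import numpy as onp


def plane_strain_stress(u_grad, E, nu):
    """Plane-strain linear elasticity in tensor form: sigma = lambda*tr(eps)*I + 2*mu*eps."""
    eps = 0.5 * (u_grad + u_grad.T)  # tensorial strain; eps[0, 1] is NOT the engineering shear strain
    lmbda = E * nu / ((1 + nu) * (1 - 2 * nu))
    mu = E / (2 * (1 + nu))
    return lmbda * onp.trace(eps) * onp.eye(2) + 2 * mu * eps


def plane_strain_D(E, nu):
    """Voigt-form matrix acting on [eps11, eps22, gamma12], with gamma12 = 2*eps12."""
    c = E / ((1 + nu) * (1 - 2 * nu))
    return c * onp.array([[1 - nu, nu, 0.0],
                          [nu, 1 - nu, 0.0],
                          [0.0, 0.0, (1 - 2 * nu) / 2]])


E, nu = 0.01, 0.3
u_grad = onp.array([[1e-3, 2e-3], [0.0, -5e-4]])
sigma = plane_strain_stress(u_grad, E, nu)
eps = 0.5 * (u_grad + u_grad.T)
sigma_voigt = plane_strain_D(E, nu) @ onp.array([eps[0, 0], eps[1, 1], 2 * eps[0, 1]])
```

Both forms give the same [σ11, σ22, σ12]; using ε12 where D expects γ12 (or vice versa) halves or doubles every shear contribution to K.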

Next, I tried to solve the same problem with a nonlinear (Neo-Hookean) material model. I first solved the 3D cylinder problem successfully and validated the displacement field against a benchmark (COMSOL).

However, when I applied the same approach to the 2D problem solved earlier with linear elasticity, the results were inaccurate: although the solver converged, the displacement in the y-direction was far too large (maximum 14.58 mm).

The MWE is below:

import numpy as onp
import numpy.testing as onptest
import jax
import jax.numpy as np
import meshio
import os

from jax_am.fem.generate_mesh import Mesh, get_meshio_cell_type
from jax_am.fem.models import HyperElasticity
from jax_am.fem.common import rectangle_mesh
from jax_am.fem.solver import solver

ele_type = 'QUAD4'
cell_type = get_meshio_cell_type(ele_type)
Lx, Ly = 120., 120.
meshio_mesh = rectangle_mesh(Nx=50, Ny=50, domain_x=Lx, domain_y=Ly)
mesh = Mesh(meshio_mesh.points, meshio_mesh.cells_dict[cell_type], ele_type=ele_type)
data_dir = os.path.join(os.path.dirname(__file__), 'data')

def fixed_xlocation(point):
    return np.isclose(point[0], 0., atol=1e-5)

def fixed_ylocation(point):
    return np.logical_and(np.isclose(point[0], 0, atol=1e-5), np.isclose(point[1], Ly, atol=1e-5))

def dbc_location(point):
    return np.isclose(point[0], Lx, atol=1e-5)

def fixed_val(point):
    return 0.

def dbc_val(point):
    return 6

dirichlet_bc_info = [[fixed_xlocation, fixed_ylocation, dbc_location], [0, 1, 0],[fixed_val, fixed_val, dbc_val]]

problem = HyperElasticity(mesh, vec=2, dim=2, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info)
sol = solver(problem, linear=False)

Could you please take a look?
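The thread does not show how HyperElasticity defines its energy, so the following is only something to check, not a diagnosis. Energy densities written for a 3×3 deformation gradient (for example with a reference trace of 3, or 3D isochoric exponents) do not carry over unchanged to a 2×2 F, and a model's built-in material constants need not match the moduli of the earlier linear problem. A plane-strain compressible neo-Hookean energy written explicitly for a 2×2 F, with the first Piola-Kirchhoff stress obtained through jax.grad, might look like this sketch; the function names and the E, ν defaults are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def psi_plane_strain(F, E, nu):
    """Compressible neo-Hookean energy for a 2x2 in-plane deformation gradient.

    Plane strain (F33 = 1), so the reference value of tr(F^T F) is 2, not 3.
    """
    mu = E / (2.0 * (1.0 + nu))
    lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
    J = jnp.linalg.det(F)
    I1 = jnp.trace(F.T @ F)
    return 0.5 * mu * (I1 - 2.0 - 2.0 * jnp.log(J)) + 0.5 * lmbda * jnp.log(J) ** 2


# First Piola-Kirchhoff stress P = d(psi)/dF via automatic differentiation.
first_pk = jax.grad(psi_plane_strain, argnums=0)


def stress_from_u_grad(u_grad, E=1.0, nu=0.3):
    F = u_grad + jnp.eye(2)
    return first_pk(F, E, nu)


# Small-strain check against the plane-strain linear law (E = 1, nu = 0.3).
H = 1e-6 * jnp.array([[1.0, 2.0], [0.0, -0.5]])
eps = 0.5 * (H + H.T)
mu, lmbda = 1.0 / 2.6, 0.3 / (1.3 * 0.4)
P_lin = lmbda * jnp.trace(eps) * jnp.eye(2) + 2.0 * mu * eps
P_nl = stress_from_u_grad(H)
```

For small displacement gradients this energy reduces to the plane-strain linear law, so with E and ν matched to the linear problem, a small applied displacement should reproduce the linear solution closely, which is a useful sanity check before raising the load.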