deepmodeling / jax-fem

Differentiable Finite Element Method with JAX
GNU General Public License v3.0

Changing mesh coordinate after defining problem #43

Open · Simhano opened this issue 3 hours ago

Simhano commented 3 hours ago

Hi, JAX-FEM community,

I have a question about changing mesh coordinates to optimize geometry.

I am trying to compute the derivative of my objective with respect to the initial (reference) configuration, i.e., the mesh coordinates.

Here is my set_params:

    def set_params(self, params):
        self.fes[0].points = params[0]
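
One plausible reason an override like this has no effect: finite element codes commonly compute geometry-dependent data (shape-function gradients, Jacobian-weighted quadrature weights) once, when the problem is constructed. Reassigning the raw coordinates afterwards then never reaches the residual. This describes a common design pattern and is an assumption here, not a statement about JAX-FEM's internals. The toy 1D bar below uses a hypothetical `Bar1D` class (not part of JAX-FEM) and NumPy only to reproduce the symptom.

```python
import numpy as np


class Bar1D:
    """Hypothetical 1D linear bar model (not JAX-FEM) that caches geometry at construction."""

    def __init__(self, points, EA=1.0):
        self.points = np.asarray(points, dtype=float)
        self.EA = EA
        # Geometry-dependent data is computed once, here, and cached.
        self.lengths = np.diff(self.points)

    def set_params(self, points):
        # Mirrors the override in the question: only the raw coordinates are replaced.
        self.points = np.asarray(points, dtype=float)

    def stiffness(self, recompute=False):
        """Assemble the global stiffness matrix from cached (or freshly computed) lengths."""
        lengths = np.diff(self.points) if recompute else self.lengths
        n = len(self.points)
        K = np.zeros((n, n))
        for e, h in enumerate(lengths):
            K[e:e + 2, e:e + 2] += (self.EA / h) * np.array([[1.0, -1.0], [-1.0, 1.0]])
        return K


bar = Bar1D([0.0, 1.0, 2.0])
K_before = bar.stiffness()
bar.set_params([0.0, 1.1, 2.2])          # scale the coordinates by 1.1, as in the question
K_cached = bar.stiffness()               # still uses the lengths cached in __init__
K_fresh = bar.stiffness(recompute=True)  # rebuilds the lengths from the new points
print(np.allclose(K_before, K_cached))   # True: the override had no effect
print(np.allclose(K_before, K_fresh))    # False: the geometry really changed
```

If the real forward solve behaves like `stiffness()` without `recompute`, the solution, and therefore any gradient flowing through the coordinates, will not see the new points.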

Then I defined the problem, which is very similar to the inverse demo in JAX-FEM. Here is how I set up the differentiation:

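    import jax
    import jax.numpy as np
    import numpy as onp
    from jax_fem.solver import ad_wrapper

    # HyperElasticity_opt, mesh, ele_type, dirichlet_bc_info, location_fns,
    # internal_pressure_2, non_fixed_nodes and u_sol_2 are defined earlier in the script.
    problem = HyperElasticity_opt(mesh, vec=3, dim=3, ele_type=ele_type, dirichlet_bc_info=dirichlet_bc_info,
                                  location_fns=location_fns, internal_pressure=internal_pressure_2)

    # onp.array makes an independent, writable copy; plain assignment would alias
    # mesh.points, so the in-place scaling below would also modify the mesh itself.
    original_cood = onp.array(mesh.points)
    original_cood[non_fixed_nodes] = original_cood[non_fixed_nodes] * 1.1

    params = [original_cood]

    # Implicit differentiation wrapper
    fwd_pred = ad_wrapper(problem)

    sol_list = fwd_pred(params)
    print(sol_list[0])

    def test_fn(sol_list):
        # Under jax.grad, plain print() shows tracers; jax.debug.print shows the values.
        jax.debug.print("{}", sol_list[0])
        return np.sum((sol_list[0] - u_sol_2)**2)

    def composed_fn(params):
        return test_fn(fwd_pred(params))

    d_coord = jax.grad(composed_fn)(params)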
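
Separately from the differentiation question, `original_cood = mesh.points` copies nothing: if `mesh.points` is a NumPy array, both names refer to the same buffer, so the in-place scaling on the next line also modifies the mesh the problem was built from. A minimal sketch with small stand-in arrays (the coordinates and `non_fixed_nodes` values are made up for illustration):

```python
import numpy as np

non_fixed_nodes = np.array([1, 2])

# Case 1: plain assignment binds a second name to the same array.
points_a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
alias = points_a
alias[non_fixed_nodes] *= 1.1
print(points_a[1, 0])  # 1.1: the "mesh" coordinates were modified too

# Case 2: an explicit copy leaves the original coordinates untouched.
points_b = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]])
perturbed = points_b.copy()
perturbed[non_fixed_nodes] *= 1.1
print(points_b[1, 0])  # 1.0
```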

My derivative (d_coord) was all zeros, and by printing sol_list[0] in the test function I could see that the coordinates used by the problem had not changed since the problem was defined. How can I change the mesh coordinates after defining the problem? If this approach does not work, is there another way to do it?
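
When a gradient comes back as exactly zero, a central finite-difference check on the same objective can separate "the sensitivity really is zero" from "the parameter never reaches the computation": if the finite-difference estimate is also zero, the coordinates are not feeding the forward solve. The sketch below uses NumPy and a hypothetical stand-in forward model, `tip_displacement` (a 1D bar fixed at node 0 with an axial load at the last node, geometry rebuilt from the coordinates on every call); in practice it would be replaced by a forward solve that consumes the perturbed coordinates.

```python
import numpy as np


def tip_displacement(points, load=1.0, EA=1.0):
    """Hypothetical forward model: tip displacement of a fixed-free 1D bar.

    Element lengths are recomputed from `points` on every call.
    """
    points = np.asarray(points, dtype=float)
    n = len(points)
    K = np.zeros((n, n))
    for e, h in enumerate(np.diff(points)):
        K[e:e + 2, e:e + 2] += (EA / h) * np.array([[1.0, -1.0], [-1.0, 1.0]])
    f = np.zeros(n)
    f[-1] = load
    # Enforce u[0] = 0 by solving only for the free nodes.
    u = np.zeros(n)
    u[1:] = np.linalg.solve(K[1:, 1:], f[1:])
    return u[-1]


def fd_gradient(fn, x, eps=1e-6):
    """Central finite-difference gradient of a scalar function of a 1D array."""
    x = np.asarray(x, dtype=float)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        grad[i] = (fn(x + step) - fn(x - step)) / (2.0 * eps)
    return grad


coords = np.array([0.0, 0.5, 1.5])
g = fd_gradient(tip_displacement, coords)
print(g)  # approximately [-1, 0, 1]
```

Here the derivative with respect to the interior node is genuinely zero (the tip displacement of this bar depends only on its total length, load * L / EA), while the end nodes give -1 and +1. A zero entry is therefore not automatically a bug, which is why comparing against finite differences is useful.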

Thank you for reading!

tianjuxue commented 2 hours ago

Unfortunately, JAX-FEM does not support taking derivatives w.r.t. mesh coordinates. But this type of problem has a very direct workaround. @xwpken Weipeng, could you share the paper that deals with this kind of problem? You need some trick to reformulate your problem (perhaps a smart definition of the deformation gradient F in some alternative reference configuration).
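
One common way to realise this kind of reformulation (a sketch under stated assumptions, not necessarily the formulation in the paper mentioned above): keep the mesh fixed as a parametric configuration with coordinates $\xi$, describe the design geometry as $X = \xi + d(\xi)$ with a design displacement field $d$, and treat $d$ as a parameter instead of moving the mesh. With $F_d = I + \nabla_\xi d$ and the physical displacement $u$ also interpolated on the fixed mesh, the chain rule gives $F = \partial x / \partial X = I + (\nabla_\xi u)\,F_d^{-1}$, and the energy pulls back as $\int_{\Omega_X} \psi(F)\,dX = \int_{\Omega_\xi} \psi(F)\,\det F_d\,d\xi$. Only fields on the fixed mesh appear, so sensitivities with respect to $d$ can follow the ordinary parameter path. The NumPy function below evaluates these quantities at a single quadrature point; `pulled_back_energy_density` and the neo-Hookean `psi` are hypothetical illustrations, not JAX-FEM functions, and wiring them into a JAX-FEM problem is not shown.

```python
import numpy as np


def psi(F, mu=1.0, kappa=10.0):
    """Hypothetical compressible neo-Hookean strain energy density (illustration only)."""
    dim = F.shape[0]
    J = np.linalg.det(F)
    I1_bar = J ** (-2.0 / dim) * np.trace(F.T @ F)
    return 0.5 * mu * (I1_bar - dim) + 0.5 * kappa * (J - 1.0) ** 2


def pulled_back_energy_density(grad_u_xi, grad_d_xi):
    """Energy per unit volume of the fixed mesh at one quadrature point.

    grad_u_xi: gradient of the physical displacement u w.r.t. fixed-mesh coordinates xi.
    grad_d_xi: gradient of the design displacement d w.r.t. xi (the geometry parameter).
    """
    eye = np.eye(grad_u_xi.shape[0])
    F_d = eye + grad_d_xi                     # dX/dxi: fixed mesh -> design geometry
    F = eye + grad_u_xi @ np.linalg.inv(F_d)  # dx/dX = I + (grad_xi u) F_d^{-1}
    return psi(F) * np.linalg.det(F_d)        # dX = det(F_d) dxi


# Consistency check: a homogeneous displacement with grad_X u = A, written on a mesh
# uniformly enlarged by 10% (d = 0.1 * xi), must recover psi(I + A) * det(F_d).
A = np.array([[0.02, 0.01, 0.0], [0.0, -0.01, 0.0], [0.0, 0.0, 0.03]])
G = 0.1 * np.eye(3)
grad_u_xi = A @ (np.eye(3) + G)               # chain rule: grad_xi u = (grad_X u) F_d
print(np.isclose(pulled_back_energy_density(grad_u_xi, G),
                 psi(np.eye(3) + A) * np.linalg.det(np.eye(3) + G)))  # True
```

With $d = 0$ the function reduces to the usual $\psi(I + \nabla u)$, so the fixed mesh can simply be the current initial mesh and the design perturbation starts at zero.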

Simhano commented 2 hours ago

Thank you so much! I really appreciate your quick response, and I would sincerely appreciate it if Weipeng could share the paper that deals with this kind of problem!

Is it that JAX-FEM cannot take derivatives w.r.t. mesh coordinates because the mesh coordinates cannot be changed after the problem is defined?

Thank you so much again!