deepmodeling / jax-fem

Differentiable Finite Element Method with JAX
GNU General Public License v3.0

NaN encountered when disabling PETSc in topology opt example #27

Open BBBBBbruce opened 3 months ago

BBBBBbruce commented 3 months ago

Hi, I wonder why disabling PETSc in the demos/topology_optimisation example leads to NaNs.

The line to change is fwd_pred = ad_wrapper(problem, linear=True, use_petsc=True, use_petsc_adjoint=False). use_petsc_adjoint can be either True or False and the results are not affected, but when I disable use_petsc (which I guess makes the program solve the linear systems with jax_solve) I get a runtime error saying NaN was encountered. This happens after a few iterations.

I tried setting JAX to 64-bit precision, but the error persists. I found that the PETSc solver is only used for the forward pass; in the custom backward pass PETSc is not used. So why would jax_solve fail in this case?

tianjuxue commented 3 months ago

This is a situation we have also encountered. I suspect that the linear system is quite stiff because some elements have a very small modulus. This makes things difficult for linear solvers, so even if we call the same bicgstab method, the PETSc result will differ from the JAX one. Sometimes the solver does not converge at all.

I might be wrong, but this has also been a long-standing issue for us.
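To make the suspicion above concrete, here is a minimal sketch in a toy 1D setting. It is not the jax-fem system: `chain_stiffness`, the spring values, and the chain length are all hypothetical. It only illustrates how a single element with a very small modulus can inflate the condition number of a stiffness matrix, which is the kind of system iterative solvers such as bicgstab tend to struggle with.

```python
import jax
jax.config.update("jax_enable_x64", True)  # use float64 for this illustration
import jax.numpy as jnp

def chain_stiffness(k):
    """Stiffness matrix of a 1D spring chain fixed at one end (hypothetical toy model).

    Spring i connects node i-1 to node i; the node before node 0 is clamped.
    """
    n = k.shape[0]
    # D maps nodal displacements to spring elongations: (D u)_i = u_i - u_{i-1}
    D = jnp.eye(n) - jnp.eye(n, k=-1)
    return D.T @ jnp.diag(k) @ D

n = 10
k_uniform = jnp.ones(n)                      # every element has modulus 1
k_soft = k_uniform.at[n // 2].set(1e-6)      # one nearly void element

cond_uniform = jnp.linalg.cond(chain_stiffness(k_uniform))
cond_soft = jnp.linalg.cond(chain_stiffness(k_soft))

print("condition number, uniform modulus :", float(cond_uniform))
print("condition number, one soft element:", float(cond_soft))
```

In this toy chain the soft element raises the condition number by several orders of magnitude. Whether the same mechanism explains the NaNs in the topology optimization demo is still only a hypothesis.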

BBBBBbruce commented 3 months ago

@tianjuxue Hi, I have a new finding: I suspect it is a machine precision problem. Briefly, I forced execution on the CPU and did not encounter the NaN issue. Simply wrap the script in one extra line:

    with jax.default_device(jax.devices("cpu")[0]):
        vf = 0.5
        optimizationParams = {'maxIters':51, 'movelimit':0.1}
        rho_ini = vf*np.ones((len(problem.fe.flex_inds), 1))
        numConstraints = 1
        optimize(problem.fe, rho_ini, optimizationParams, objectiveHandle, consHandle, numConstraints)
        print(f"As a reminder, compliance = {J_total(np.ones((len(problem.fe.flex_inds), 1)))} for full material")
        # the rest of the script

Also, I found that jax_enable_x64 has no effect when running on GPU: whether it is True or False, the solver always stays in single precision. I am not sure whether this is a JAX bug or a bug in the jax-fem framework.
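One way to narrow this down is to check which dtype and device JAX actually uses, independently of jax-fem. The snippet below is a minimal sketch using only standard JAX calls; the array `x` is just a probe and has nothing to do with the demo. It assumes the flag is set at the very top of the script, before any arrays are created.

```python
import jax

# Enable 64-bit mode as early as possible, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)  # probe array; takes the default float dtype

print("x64 flag      :", jax.config.jax_enable_x64)
print("default dtype :", x.dtype)
print("devices       :", jax.devices())
print("x lives on    :", x.devices())
```

If the probe reports float32 on the GPU, the flag is not taking effect in that environment. If it reports float64, the single-precision arrays are more likely being created somewhere in the script or the framework.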

tianjuxue commented 3 months ago

"I found that when running on GPU, jax_enable_x64 is not working"

This is a bit strange. Are you talking about the PETSc solver or the JAX built-in solver?

BBBBBbruce commented 3 months ago

I disabled PETSc completely, so only JAX-based methods are being used.

tianjuxue commented 3 months ago

This does not happen to us. Usually, if float32 is used, the JAX bicgstab solver will never converge, even for a simple problem.

BBBBBbruce commented 3 months ago

Hi, I have another question regarding the topology optimization demos. I will reuse the notation from the README in the following description.

My question is about the custom VJP (the implicit_vjp() function): it seems to evaluate the Hessian of the linear elasticity equation.

My understanding is: the optimisation problem can be written as J(u) with u = argmin_u C(u, x), where C(u, x) is the linear elasticity equation and u is solved by the Newton update u_{k+1} = u_k - [dC/du_k]^{-1} C(u_k).

The function get_A_fn(problem, use_petsc) prepares dC/du, which is already the gradient,

and my problem is with the backward pass: adjoint_linear_fn = get_vjp_contraint_fn_dofs(dofs) is supposed to compute dC/du as well (according to equation 21 in the paper). However, get_vjp_contraint_fn_dofs() is defined to return the VJP of A_fn (code snippet below). Is it trying to get d(dC/du)/du?

I verified the gradients against finite differences, and they are correct. Can you see where I am making a mistake?

    def get_vjp_contraint_fn_dofs(dofs):
        # Just a transpose of A_fn
        def adjoint_linear_fn(adjoint):
            primals, f_vjp = jax.vjp(A_fn, dofs)
            val, = f_vjp(adjoint)
            return val

        return adjoint_linear_fn
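The comment "Just a transpose of A_fn" can be checked on a small example. The sketch below assumes that A_fn is a linear map of the form dofs -> A @ dofs, with A standing in for dC/du; the 2x2 matrix and the vectors are hypothetical, not taken from jax-fem. For a linear function, jax.vjp does not introduce a second derivative: the pullback applied to a vector simply returns A^T times that vector, regardless of the point at which it is evaluated.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the assembled tangent matrix dC/du.
A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])

def A_fn(dofs):
    # Linear map: applies A to a vector of degrees of freedom.
    return A @ dofs

def get_vjp_fn(dofs):
    # Same pattern as the snippet above: pull an adjoint vector back through A_fn.
    def adjoint_linear_fn(adjoint):
        _, f_vjp = jax.vjp(A_fn, dofs)
        (val,) = f_vjp(adjoint)
        return val
    return adjoint_linear_fn

adjoint = jnp.array([0.5, 2.0])

# The linearization point does not matter because A_fn is linear.
result_a = get_vjp_fn(jnp.array([1.0, -1.0]))(adjoint)
result_b = get_vjp_fn(jnp.array([7.0, 4.0]))(adjoint)
expected = A.T @ adjoint

print("vjp result :", result_a)
print("A^T adjoint:", expected)
```

Under that assumption, the VJP of A_fn yields the action of the transposed operator, which is what an adjoint solve needs, rather than d(dC/du)/du.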