deepmodeling / jax-fem

Differentiable Finite Element Method with JAX
GNU General Public License v3.0
290 stars, 45 forks

Multi-step loading + auto-diff #5

Open SNMS95 opened 1 year ago

SNMS95 commented 1 year ago
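```python
# Sketch of the proposal. `available_problem_bcs`, `box_domain`, `fe_problem`,
# `ad_wrapper`, `parametrizer_fn_with_x`, `cone_filter_mapping`, `objective_fn`,
# `constraint_fn` and `force_magnitudes` come from the surrounding TO-JAX / JAX-FEM setup.
import jax


def non_trace_force_stepping_fn(design, force_magnitudes):
    """Apply the load in increments, warm-starting each solve from the previous one."""
    u_init = 0.0  # zero initial guess for the first load step
    for force_magnitude in force_magnitudes:
        # Rebuild the Neumann BCs for the current load increment
        problem_bcs = available_problem_bcs.ProblemBCsDatabase.get_problem(
            'cantilever_2D', box_domain, force_magnitude=force_magnitude)
        fe_problem.neumann_bc_info = problem_bcs.neumann_bc_info

        # Solve the forward problem; `u_init` is the proposed new
        # ad_wrapper argument (change 1 below)
        fwd_pred = ad_wrapper(problem=fe_problem, linear=True, use_petsc=False,
                              u_init=u_init)
        u = fwd_pred(design)
        u_init = u  # warm start for the next increment
    return u, fwd_pred


def pipeline_fn(params, state_of_params=None):
    if state_of_params is None:
        state_of_params = {}  # avoid a shared mutable default argument
    design, state_of_params = parametrizer_fn_with_x(params=params,
                                                     state=state_of_params)
    design = cone_filter_mapping(design)
    # Load stepping runs outside the AD trace ...
    u_final, fwd_pred = non_trace_force_stepping_fn(jax.lax.stop_gradient(design),
                                                    force_magnitudes)
    # ... and gradients come from one solve at the final load (the BCs are
    # still set to the last increment), warm-started from the converged state
    fe_solution = fwd_pred([design], u_init=u_final)
    objective_val = objective_fn(fe_solution)
    constraint_val = constraint_fn(design)
    return {
        "fe_solution": fe_solution,
        "design": design,
        "objective": objective_val,
        "constraint": constraint_val,
        "nn_state": state_of_params,
    }
```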
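A minimal, self-contained sketch of why this two-stage idea can give correct gradients, using a hypothetical one-DOF "hardening spring" residual in place of the FE problem (`residual`, `newton_step`, `load_stepping` and `pipeline` are illustrative names, not JAX-FEM API). One swap to note: the final differentiable solve here is a single traced Newton step, not JAX-FEM's adjoint-based `ad_wrapper`. Because the residual vanishes at the converged state, the derivative of that one step reduces to the implicit-function-theorem derivative $-(\partial r/\partial k)/(\partial r/\partial u)$, so nothing inside the load-stepping loop needs to be traced.

```python
import jax
import jax.numpy as jnp

# Use float64 so the converged residual is at round-off level
jax.config.update("jax_enable_x64", True)


def residual(u, k, f):
    # Hypothetical 1-DOF "hardening spring": r(u; k, f) = k*u + u**3 - f
    return k * u + u**3 - f


def newton_step(u, k, f):
    # One Newton update: u <- u - r / (dr/du)
    drdu = jax.grad(residual, argnums=0)(u, k, f)
    return u - residual(u, k, f) / drdu


def load_stepping(k, force_magnitudes, n_iter=20):
    # Incremental loading: each step is warm-started from the previous one
    u = 0.0
    for f in force_magnitudes:
        for _ in range(n_iter):
            u = newton_step(u, k, f)
    return u


def pipeline(k, force_magnitudes):
    # Stage 1: load stepping on a stop_gradient'ed design (not differentiated)
    u_final = load_stepping(jax.lax.stop_gradient(k), force_magnitudes)
    # Stage 2: one differentiable Newton step at the final load,
    # starting from the converged state
    return newton_step(u_final, k, force_magnitudes[-1])
```

With, say, `k = 2.0` and `forces = jnp.array([0.5, 1.0, 1.5, 2.0])`, `jax.grad(pipeline)(k, forces)` should match both `-u / (k + 3 * u**2)` and the much more expensive gradient obtained by differentiating through every load step.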

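Running the load steps under `jax.lax.stop_gradient` and then doing a single differentiable solve, warm-started from the converged displacement, seems like a simple solution. Changes to be made: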

  1. Adjust ad_wrapper to accept u_init_guess [Modify JAX-FEM]
  2. Expose force_magnitude in problem_bcs [TO-JAX]
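One way change 1 could look, sketched with `jax.custom_vjp` on a hypothetical elementwise toy residual (`residual`, `newton` and `solve` are illustrative names; this is not the actual `ad_wrapper` implementation). The forward pass starts Newton from the supplied initial guess, while the backward pass uses only the converged state and an adjoint solve, so the initial guess receives a zero cotangent and can be changed freely without altering the gradients.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def residual(u, theta):
    # Hypothetical nonlinear system (elementwise): theta*u + u**3 - 1 = 0
    return theta * u + u**3 - 1.0


def newton(u, theta, n_iter=30):
    # Plain Newton iterations starting from the supplied initial guess
    for _ in range(n_iter):
        J = jax.jacobian(residual, argnums=0)(u, theta)
        u = u - jnp.linalg.solve(J, residual(u, theta))
    return u


@jax.custom_vjp
def solve(theta, u_init):
    # Forward solve warm-started from u_init
    return newton(u_init, theta)


def solve_fwd(theta, u_init):
    u = newton(u_init, theta)
    return u, (u, theta)


def solve_bwd(res, u_bar):
    u, theta = res
    # Adjoint solve with the Jacobian at the converged state
    J_u = jax.jacobian(residual, argnums=0)(u, theta)
    lam = jnp.linalg.solve(J_u.T, u_bar)
    # theta_bar = -lam^T (dr/dtheta), from the implicit function theorem
    _, vjp_theta = jax.vjp(lambda t: residual(u, t), theta)
    (theta_bar,) = vjp_theta(-lam)
    # The initial guess does not change the converged solution: zero cotangent
    return theta_bar, jnp.zeros_like(u)


solve.defvjp(solve_fwd, solve_bwd)
```

Taking `jax.grad` of `jnp.sum(solve(theta, u_init))` with `u_init = jnp.zeros(3)` or `jnp.ones(3)` should give the same gradient, which is the property the warm-started final solve in `pipeline_fn` relies on.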

@acse-itk22

h-vijayakumaran commented 1 year ago

I think the jax-am plasticity example would be a good starting point.