[Open] Issue opened by SNMS95 one year ago.
def non_trace_force_stepping_fn(design, force_magnitues):
    """Solve the forward FE problem by stepping through a sequence of load magnitudes.

    For every load step, the 'cantilever_2D' boundary conditions are rebuilt for
    the current force magnitude and their Neumann part is installed on the
    module-level ``fe_problem``. A new forward predictor is then built with
    ``ad_wrapper``, passing the previous step's solution as ``u_init`` (intended as
    the solver's initial guess), and the problem is solved for ``design``.

    Args:
        design: Design variables passed to the forward predictor at every step.
        force_magnitues: Sequence of force magnitudes, one per load step, applied
            in order. (The misspelt name is kept so existing keyword callers keep
            working.)

    Returns:
        Tuple ``(u, fwd_pred)``: the solution of the last load step and the
        forward predictor that produced it.

    Raises:
        ValueError: If ``force_magnitues`` is empty, because no step would be
            solved and there would be nothing to return.
    """
    # Fix: the previous body ignored this argument and read the module-level
    # ``force_magnitudes`` / ``n_steps`` instead. Iterate over what was passed in.
    force_magnitudes = list(force_magnitues)
    if not force_magnitudes:
        raise ValueError("force_magnitues must contain at least one load step")

    u_init = 0.0  # initial value handed to the first step's predictor
    u, fwd_pred = None, None
    for force_magnitude in force_magnitudes:
        # Rebuild the boundary conditions for this load step. Note: this mutates
        # the shared, module-level fe_problem in place.
        problem_bcs = available_problem_bcs.ProblemBCsDatabase.get_problem(
            'cantilever_2D', box_domain, force_magnitude=force_magnitude)
        fe_problem.neumann_bc_info = problem_bcs.neumann_bc_info
        # Solve the forward problem, seeded with the previous step's solution.
        fwd_pred = ad_wrapper(problem=fe_problem, linear=True, use_petsc=False,
                              u_init=u_init)
        u = fwd_pred(design)
        u_init = u
    return u, fwd_pred
def pipeline_fn(params, state_of_params=None):
    """Map parameters to a design, solve the FE problem and evaluate the outputs.

    Steps:
      1. ``parametrizer_fn_with_x`` turns ``params`` into a design and updates
         the parametrizer state.
      2. ``cone_filter_mapping`` filters the design.
      3. ``non_trace_force_stepping_fn`` runs the load-stepping solves on a
         ``jax.lax.stop_gradient`` copy of the design, so gradients do not flow
         through the stepping loop.
      4. The last step's predictor is called once more on the live design,
         seeded with the load-stepped solution. This is the call that gradients
         are presumably meant to flow through (via ``ad_wrapper``); confirm.

    Args:
        params: Parameters consumed by ``parametrizer_fn_with_x``.
        state_of_params: Parametrizer state. Defaults to a fresh empty dict on
            each call.

    Returns:
        Dict with keys "fe_solution", "design", "objective", "constraint" and
        "nn_state" (the updated parametrizer state).
    """
    # Fix: a literal ``{}`` default would be one dict shared across every call.
    if state_of_params is None:
        state_of_params = {}
    design, state_of_params = parametrizer_fn_with_x(params=params,
                                                     state=state_of_params)
    design = cone_filter_mapping(design)
    u_final, fwd_pred = non_trace_force_stepping_fn(
        jax.lax.stop_gradient(design), force_magnitudes)
    # NOTE(review): the stepping function calls fwd_pred(design), but here the
    # design is wrapped in a list and u_init is passed at call time. Confirm
    # which call signature ad_wrapper's predictor actually expects.
    fe_solution = fwd_pred([design], u_init=u_final)
    objective_val = objective_fn(fe_solution)
    constraint_val = constraint_fn(design)
    return {
        "fe_solution": fe_solution,
        "design": design,
        "objective": objective_val,
        "constraint": constraint_val,
        "nn_state": state_of_params,
    }
This seems like an easy solution. Changes to be made:
- `ad_wrapper` to accept `u_init_guess` [Modify JAX-FEM]
- `force_magnitude` in `problem_bcs` [TO-JAX] @acse-itk22
I think the jax-am plasticity example would be a good starting point.
`ad_wrapper` performs implicit differentiation through a nonlinear solve; we should also allow implicit differentiation through the load steps!