deepmodeling / jax-fem

Differentiable Finite Element Method with JAX
GNU General Public License v3.0

How to do parallel solver #25

Open mmorier-etu opened 2 months ago

mmorier-etu commented 2 months ago

Hi, I'm trying to create a hybrid model that uses your FEM solver together with a neural network. To do that, I need to solve the same equations with different parameters (e.g. heat diffusion, where the diffusion coefficient and the initial conditions differ for each sample). I can't use vmap because the solver uses scipy and numpy, which aren't compatible with it. Do you think the solver could be adapted so that it can handle batches or be passed into vmap? Thanks in advance for any ideas!

tianjuxue commented 2 months ago

Solving a linear system involves iterative steps, so it is not something that can simply be "batched". JAX-FEM uses the JAX version of scipy, not the original scipy, but I still doubt whether batched operation is even possible for linear solvers.
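
As a point of reference for the batching question, here is a minimal sketch on a toy problem that is not JAX-FEM code: a 1D finite-difference diffusion operator solved with `jax.scipy.sparse.linalg.cg`, wrapped in `jax.vmap` over the diffusion coefficient. The grid size `N`, the helper `solve_one`, and the coefficient values are all hypothetical. It only illustrates that a JAX-native iterative solver can be batched in this small setting; it says nothing about whether JAX-FEM's own solver path can be adapted the same way.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)  # double precision for a tight tolerance

N = 32             # number of interior grid points (hypothetical toy size)
h = 1.0 / (N + 1)  # grid spacing on the unit interval


def solve_one(kappa, f):
    """Solve -kappa * u'' = f on (0, 1) with u(0) = u(1) = 0 using matrix-free CG."""
    def matvec(u):
        up = jnp.pad(u, 1)  # zero Dirichlet values at both ends
        return kappa * (2.0 * up[1:-1] - up[:-2] - up[2:]) / h**2

    u, _ = cg(matvec, f, tol=1e-10, maxiter=500)
    return u


kappas = jnp.array([0.5, 1.0, 2.0])  # one diffusion coefficient per sample
f = jnp.ones(N)                      # same source term for every sample

# Batch over the coefficient only; the right-hand side is shared.
us = jax.vmap(solve_one, in_axes=(0, None))(kappas, f)  # shape (3, N)
```

In the batched run the CG loop keeps iterating until every sample has converged, so the cost is set by the slowest sample.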

mmorier-etu commented 2 months ago

In solver, the function get_A_fn uses scipy, doesn't it? I saw that vmap can be applied to odeint, so I thought it could work with JAX-FEM too. But maybe I should avoid this approach if I want to use JAX-FEM?

Here is the error I get when I try to use vmap:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/user/mmorier/home/Documents/code/jax-fem-hybride/dataset_processing/data_para.py", line 118, in <module>
    sols = solve_problem_vec(sol_dts, ds)
  File "/user/mmorier/home/Documents/code/jax-fem-hybride/dataset_processing/data_para.py", line 98, in solve_problem
    sol = solver(problem_source, linear=True, use_petsc=True)[0]
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/solver.py", line 888, in solver
    return solver_row_elimination(problem, linear, precond, initial_guess, use_petsc, petsc_options)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/solver.py", line 404, in solver_row_elimination
    res_vec, A_fn = newton_update_helper(dofs)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/solver.py", line 395, in newton_update_helper
    res_list = problem.newton_update(sol_list)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/problem.py", line 414, in newton_update
    return self.compute_newton_vars(sol_list, self.internal_vars, self.internal_vars_surfaces)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/problem.py", line 400, in compute_newton_vars
    weak_form_flat, cells_jac_flat = self.split_and_compute_cell(cells_sol_flat, onp, True, internal_vars)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/common.py", line 95, in timeit_wrapper
    result = func(*args, **kwargs)
  File "/home/mmorier/Documents/code/jax-fem/jax_fem/problem.py", line 329, in split_and_compute_cell
    values = np_version.vstack(values)
  File "<__array_function__ internals>", line 200, in vstack
  File "/user/mmorier/home/.conda/envs/jax-fem-env/lib/python3.9/site-packages/numpy/core/shape_base.py", line 293, in vstack
    arrs = atleast_2d(*tup)
  File "<__array_function__ internals>", line 200, in atleast_2d
  File "/user/mmorier/home/.conda/envs/jax-fem-env/lib/python3.9/site-packages/numpy/core/shape_base.py", line 121, in atleast_2d
    ary = asanyarray(ary)
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float64[51,4].
This BatchTracer with object id 140145456435504 was created on line:
  /home/mmorier/Documents/code/jax-fem/jax_fem/problem.py:326 (split_and_compute_cell)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
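
The failure in the traceback comes from plain NumPy's `vstack` being handed a traced value inside `vmap`. The same error can be reproduced in isolation with the small sketch below. The functions `stack_with_numpy` and `stack_with_jax` are hypothetical stand-ins, not JAX-FEM code, and the `onp`/`np` aliases are assumed to mean plain NumPy and `jax.numpy` respectively, as in this thread.

```python
import numpy as onp     # plain NumPy, the alias that appears in the traceback
import jax
import jax.numpy as np  # assumed alias for jax.numpy


def stack_with_numpy(x):
    # onp.vstack tries to convert the traced value into a concrete ndarray.
    return onp.vstack([x, 2.0 * x])


def stack_with_jax(x):
    # np.vstack stays inside JAX, so it can be traced and batched.
    return np.vstack([x, 2.0 * x])


batch = np.ones((3, 4))

try:
    jax.vmap(stack_with_numpy)(batch)
    numpy_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_failed = True  # same error class as in the traceback above

stacked = jax.vmap(stack_with_jax)(batch)  # shape (3, 2, 4)
```

Any code path that routes a traced array through `onp` will hit this, regardless of what the linear solver itself does.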

Is there a reason you are using onp instead of np? I couldn't work it out.

tianjuxue commented 2 months ago

It is to save GPU memory. You can use JAX-FEM to solve a problem with over 1 million DOFs on a single GPU. To get there we need to save GPU memory, so in many places the arrays are kept in onp.

As for odeint, I don't think it solves linear systems; it's more or less explicit updates.
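
To illustrate why explicit updates batch so easily, here is a minimal sketch of a forward-Euler heat-diffusion loop vmapped over both the diffusion coefficient and the initial condition. It uses a hypothetical 1D finite-difference grid and a hypothetical helper `explicit_heat`; forward Euler is chosen for brevity and is not the scheme odeint uses internally. Each step is just array arithmetic, so there is no linear system to solve.

```python
import jax
import jax.numpy as jnp

N = 50             # interior grid points (hypothetical toy size)
h = 1.0 / (N + 1)  # grid spacing on the unit interval


def explicit_heat(kappa, u0, dt=1e-4, steps=100):
    """Advance u_t = kappa * u_xx with forward Euler and zero Dirichlet boundaries."""
    def step(u, _):
        up = jnp.pad(u, 1)  # zero boundary values
        lap = (up[:-2] - 2.0 * up[1:-1] + up[2:]) / h**2
        return u + dt * kappa * lap, None

    u, _ = jax.lax.scan(step, u0, None, length=steps)
    return u


x = jnp.arange(1, N + 1) * h
kappas = jnp.array([0.25, 0.5, 1.0])  # chosen so that dt * kappa / h**2 <= 0.5
amplitudes = jnp.array([1.0, 2.0, 0.5])
u0s = amplitudes[:, None] * jnp.sin(jnp.pi * x)  # one initial condition per sample

# Batch over both the coefficient and the initial condition.
us = jax.vmap(explicit_heat, in_axes=(0, 0))(kappas, u0s)  # shape (3, N)
```

The price of the explicit scheme is the stability limit on the time step, which is why implicit FEM solvers end up with linear systems in the first place.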

mmorier-etu commented 2 months ago

OK, thanks! I understand now. So hypothetically, I can replace onp with np if I don't have a large number of DOFs, right? And what about the scipy call in get_A_fn in solver:

    A_sp_scipy = scipy.sparse.csr_array(
        (onp.array(problem.V), (problem.I, problem.J)),
        shape=(problem.num_total_dofs_all_vars, problem.num_total_dofs_all_vars))

Would it be a problem if I replaced it with a JAX experimental sparse version?

OK, I will take a closer look at it then.