Open mmorier-etu opened 2 months ago
Solving a linear system involves iterative steps, so it is not straightforward to simply "batch" it. JAX-FEM uses the JAX version of scipy, not the original scipy, but I still doubt whether it is even possible to batch linear solvers.
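For reference, a minimal sketch of what batching an iterative solve can look like in plain JAX, outside of JAX-FEM. It assumes a small matrix-free operator `I + kappa * L` (a hypothetical stand-in, not JAX-FEM's assembled matrix) and uses `jax.scipy.sparse.linalg.cg`, whose loop is written with JAX control flow and can therefore be wrapped in `vmap`. Whether this carries over to JAX-FEM's own solver path is not shown here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

n = 16
# 1D Laplacian stencil (tridiagonal, symmetric positive definite).
L = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)

def solve_one(kappa, b):
    # Matrix-free operator for (I + kappa * L), SPD for kappa >= 0.
    # `kappa` and `L` are illustrative names, not JAX-FEM objects.
    matvec = lambda x: x + kappa * (L @ x)
    x, _ = cg(matvec, b, tol=1e-6, maxiter=200)
    return x

kappas = jnp.array([0.1, 0.5, 1.0])   # one coefficient per sample
bs = jnp.ones((3, n))                 # one right-hand side per sample
xs = jax.vmap(solve_one)(kappas, bs)  # batched solve, shape (3, n)
```

Under `vmap`, the batched loop keeps iterating until every sample has converged, so the cost is set by the slowest sample in the batch.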
In solver, the function get_A_fn uses scipy, doesn't it? I saw that vmap can be used on odeint, so I thought it could work with JAX-FEM. But maybe I should avoid this approach if I want to use JAX-FEM?
Here is the error I get when I try to use vmap:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/user/mmorier/home/Documents/code/jax-fem-hybride/dataset_processing/data_para.py", line 118, in
Is there a reason that you are using onp instead of np? I couldn't understand it.
It is to save GPU memory. You can use JAX-FEM to solve a problem with over 1 million DOFs on a single GPU. To keep GPU memory free, arrays in many places are kept as onp arrays.
For odeint, I don't think it solves linear systems; it's more or less explicit updates.
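To illustrate why odeint batches easily, here is a small sketch that applies `vmap` to `jax.experimental.ode.odeint` for a 1D heat equation with a different diffusion coefficient and initial condition per sample. The finite-difference right-hand side, the grid, and all names (`heat_rhs`, `simulate`, `kappas`, `u0s`) are assumptions made for this example and are unrelated to JAX-FEM's discretization.

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

n = 32
dx = 1.0 / (n + 1)

def heat_rhs(u, t, kappa):
    # Second-order finite-difference Laplacian, zero Dirichlet boundaries.
    u_pad = jnp.pad(u, 1)
    lap = (u_pad[:-2] - 2.0 * u + u_pad[2:]) / dx**2
    return kappa * lap

def simulate(kappa, u0, ts):
    # Explicit adaptive Runge-Kutta integration; no linear system is solved.
    return odeint(heat_rhs, u0, ts, kappa, rtol=1e-5, atol=1e-6)

x = jnp.linspace(dx, 1.0 - dx, n)
kappas = jnp.array([0.01, 0.05])                                 # per-sample coefficient
u0s = jnp.stack([jnp.sin(jnp.pi * x), jnp.sin(2 * jnp.pi * x)])  # per-sample initial condition
ts = jnp.linspace(0.0, 0.1, 5)

# Batch over coefficient and initial condition; share the time grid.
us = jax.vmap(simulate, in_axes=(0, 0, None))(kappas, u0s, ts)   # shape (2, 5, n)
```

Every operation in the right-hand side is a `jax.numpy` call, which is what lets `vmap` trace through the whole integration.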
OK, thanks! I understand now. So hypothetically, I can replace onp with np if I don't have a large number of DOFs, right? And for the scipy call in solver, in get_A_fn: A_sp_scipy = scipy.sparse.csr_array((onp.array(problem.V), (problem.I, problem.J)), shape=(problem.num_total_dofs_all_vars, problem.num_total_dofs_all_vars))
would it be a problem if I replaced it with a JAX experimental sparse version?
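As a rough sketch of what such a replacement could look like, the snippet below builds the same matrix once with `scipy.sparse.csr_array` and once with `jax.experimental.sparse.BCOO` from COO-style triplets. The arrays `V`, `I`, `J` and the size `n` are tiny hypothetical stand-ins for `problem.V`, `problem.I`, `problem.J` and `problem.num_total_dofs_all_vars`; `jnp` is `jax.numpy`, which this thread calls `np`. It only shows the construction and a matrix-vector product, not whether the rest of the JAX-FEM solver accepts a BCOO matrix.

```python
import numpy as onp
import scipy.sparse
import jax.numpy as jnp
from jax.experimental import sparse

# Hypothetical stand-ins for problem.V, problem.I, problem.J.
# Entry (1, 1) appears twice, as happens during FEM assembly.
V = onp.array([4.0, -1.0, -1.0, 4.0, 1.0, 3.0])
I = onp.array([0, 0, 1, 1, 1, 2])
J = onp.array([0, 1, 0, 1, 1, 2])
n = 3

# Original construction: scipy sums duplicate (row, col) entries.
A_sp_scipy = scipy.sparse.csr_array((V, (I, J)), shape=(n, n))

# BCOO expects the indices as one array of shape (nse, 2).
indices = jnp.stack([jnp.asarray(I), jnp.asarray(J)], axis=1)
A_bcoo = sparse.BCOO((jnp.asarray(V), indices), shape=(n, n))
A_bcoo = A_bcoo.sum_duplicates()  # merge repeated entries explicitly

x = jnp.arange(1.0, n + 1.0)
y = A_bcoo @ x  # sparse matrix-vector product on the JAX side
```

Note that the BCOO data lives on the accelerator, so this trades away the GPU memory savings mentioned above.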
OK, I will look at it more closely then.
Hi, I'm trying to create a hybrid model that uses your FEM solver and a neural network. To do that, I need to solve the same equations with different parameters (e.g. heat diffusion, where the diffusion coefficient and initial condition are different for each sample). I can't use vmap because the solver uses scipy and numpy, which aren't compatible with it. Do you think the solver can be adapted so that it can handle batches or be passed to vmap? Thanks in advance for any ideas!