tianjuxue / jax-am

Additive manufacturing simulation with JAX.
https://jax-am.readthedocs.io/en/latest/
GNU General Public License v3.0
268 stars 56 forks

Improving autodiff capabilities of JAX-FEM module #15

Closed SNMS95 closed 1 year ago

SNMS95 commented 1 year ago

Hi,

Ideally, the solvers should be chosen from JAX to have easy access to all of JAX's transformations. However, we know that external libraries like PETSc have very efficient and diverse solvers, which can be treated as black-box solvers in JAX. So, we would like to use these solvers but at the same time maintain as much of JAX's capabilities as possible.

Currently, the solver part is implemented (the ad_wrapper function in solver.py) by attaching a custom_vjp rule, which overrides the default traced autodiff. However, custom_vjp does not prevent JAX from tracing into the primal function. Since that function calls external libraries, this results in errors when JAX-FEM is used inside larger traced pipelines (e.g. under jit). Furthermore, providing only a vjp rule also blocks forward-mode autodiff.
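A minimal sketch of the current pattern (here `external_solve` is a hypothetical NumPy stand-in for a PETSc call, not JAX-FEM's actual solver):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for an external (non-JAX) linear solver, e.g. PETSc:
# it works on concrete NumPy arrays and fails on JAX tracers.
def external_solve(A, b):
    return np.linalg.solve(np.asarray(A), np.asarray(b))

@jax.custom_vjp
def solve(A, b):
    return jnp.asarray(external_solve(A, b))

def solve_fwd(A, b):
    x = solve(A, b)
    return x, (A, x)

def solve_bwd(res, g):
    A, x = res
    # adjoint solve: A^T lam = g, then dA = -lam x^T, db = lam
    lam = jnp.asarray(external_solve(A.T, g))
    return (-jnp.outer(lam, x), lam)

solve.defvjp(solve_fwd, solve_bwd)
```

jax.grad works on concrete inputs because the fwd/bwd rules run eagerly, but jax.jit(solve) still traces into external_solve (np.asarray fails on a tracer), and jax.jvp(solve, ...) raises because only a vjp rule is defined.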

As of now, there are two ways to handle this:

  1. Create a Primitive and assign individual transformation rules (Best for performance but probably very hard!) - We will have to do this wherever we choose external libraries.
  2. Use callbacks for all external calls (this forces those operations to run on the host CPU, but preserves access to all of JAX's transforms and is much easier and cleaner). Specifying custom_jvp + pure_callback would ensure that the entire package is consistent.
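A sketch of option 2, using a hypothetical host-side routine (`host_log1p` stands in for an external call such as a PETSc solve):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical external routine: runs on the host with concrete NumPy arrays.
def host_log1p(x):
    return np.log1p(x)

@jax.custom_jvp
def log1p_ext(x):
    # pure_callback stops JAX from tracing into the external code, so the
    # wrapper survives jit/vmap; the callback itself runs on the host CPU.
    return jax.pure_callback(host_log1p,
                             jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@log1p_ext.defjvp
def log1p_ext_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = log1p_ext(x)
    # d/dx log(1 + x) = 1 / (1 + x); the rule is linear in t, so JAX can
    # transpose it to recover reverse mode (jax.grad) automatically.
    return y, t / (1.0 + x)
```

With custom_jvp, reverse mode comes for free by transposition, so both jax.jvp and jax.grad work, and jit no longer traces into the external call.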

I suggest we discuss this possibility during our meeting.

tianjuxue commented 1 year ago

It'd be better if you also prepare some small examples that actually demand these proposed changes. Thanks!

SNMS95 commented 1 year ago

Sure! Just to give a heads-up:

It is related to making use of Hessian-vector products (with forward-over-reverse autodiff). I believe that easy access to HVPs is an important advantage over traditional hand-written sensitivity analysis. They would be useful in optimization, in analysis, and in solving!
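For reference, the forward-over-reverse HVP pattern in JAX (a generic sketch, not tied to JAX-FEM):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)

def hvp(f, x, v):
    # forward-over-reverse: jvp of grad(f) gives H @ v without forming H
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 0.0])
# Hessian of f is diag(6 * x), so H @ v = [6.0, 0.0]
print(hvp(f, x, v))
```

This composition only works end-to-end if every operation in f supports both forward and reverse mode, which is exactly what a custom_vjp-only wrapper breaks.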

Plus, once we delve further into ML, e.g. meta-learning of some sort, being able to differentiate through an optimization becomes important (as done in JAXOpt).
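As a toy illustration of differentiating through an optimization (an unrolled gradient-descent sketch; JAXOpt does this more efficiently via implicit differentiation):

```python
import jax

def inner_solve(theta, steps=100, lr=0.1):
    # minimize (x - theta)**2 with plain gradient descent; x* -> theta
    x = 0.0
    for _ in range(steps):
        x = x - lr * jax.grad(lambda x: (x - theta) ** 2)(x)
    return x

# differentiate the optimum w.r.t. the parameter: d(x*)/d(theta) -> 1
dx_dtheta = jax.grad(inner_solve)(2.0)
print(dx_dtheta)
```

The outer jax.grad flows through every inner solver step, which is exactly the capability that breaks when an opaque external solver sits inside the loop.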

SNMS95 commented 1 year ago

The PETSc solver seems to be broken. The following call

solver(problem, linear=True, use_petsc=True)

results in the error:

TypeError: Cannot cast array data from dtype('int64') to dtype('int32') according to the rule 'safe'
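The message comes from NumPy's 'safe' casting rule; a likely cause (an assumption, not verified against the code) is that 64-bit index arrays are being handed to a PETSc build with 32-bit PetscInt. A minimal reproduction of the cast failure and the explicit workaround:

```python
import numpy as np

idx = np.arange(5, dtype=np.int64)

# 'safe' casting rejects int64 -> int32, reproducing the error above:
try:
    idx.astype(np.int32, casting='safe')
except TypeError as e:
    print(e)

# an explicit downcast works whenever the indices fit in 32 bits:
idx32 = idx.astype(np.int32)
```

If this is indeed the cause, casting the sparsity-pattern index arrays to np.int32 before the petsc4py call would be a workaround.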

[screenshot of the error traceback]