Closed: SNMS95 closed this issue 1 year ago
It'd be better if you also prepare some small examples that actually demand these proposed changes. Thanks!
Sure! Just to give you a heads-up: it is related to making use of Hessian-vector products (via forward-over-reverse autodiff). I believe that easy access to HVPs is an important capability beyond traditional hand-written sensitivity analysis. They would be useful in optimization, in analysis, and in solving!
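For instance, a forward-over-reverse HVP is a short composition in JAX; here is a minimal sketch with a toy objective `f` (a stand-in, not an actual JAX-FEM functional):

```python
import jax
import jax.numpy as jnp

# Toy scalar objective; a stand-in for a compliance or loss functional.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Forward-over-reverse HVP: push a tangent v through the gradient of f.
# This evaluates H(x) @ v without ever materializing the Hessian.
def hvp(f, x, v):
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

x = jnp.arange(3.0)
v = jnp.ones_like(x)
print(hvp(f, x, v))
```

Note that the outer `jax.jvp` needs forward mode through the gradient, which is exactly what a `custom_vjp`-only wrapper cannot provide (see below).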
Plus, once we delve further into ML, e.g. meta-learning of some sort, being able to differentiate through an optimization becomes important (as done in JAXOpt); a toy sketch follows.
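As an illustration (not JAXOpt's implicit-differentiation machinery, just a plain unrolled inner loop):

```python
import jax
import jax.numpy as jnp

# Inner problem: minimize ||x - theta||^2 over x.
def inner_loss(x, theta):
    return jnp.sum((x - theta) ** 2)

# Solve the inner problem with a few gradient steps; JAX unrolls the loop,
# so the whole solve stays differentiable with respect to theta.
def solve(theta, steps=50, lr=0.1):
    x = jnp.zeros_like(theta)
    for _ in range(steps):
        x = x - lr * jax.grad(inner_loss)(x, theta)
    return x

# Outer objective evaluated at the inner solution.
def outer_loss(theta):
    x_star = solve(theta)
    return jnp.sum(x_star ** 2)

theta = jnp.array([1.0, -2.0])
print(jax.grad(outer_loss)(theta))  # gradient *through* the optimization
```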
The PETSc solver seems to be broken. The following call

```python
solver(problem, linear=True, use_petsc=True)
```

results in an error:

```
TypeError: Cannot cast array data from dtype('int64') to dtype('int32') according to the rule 'safe'
```
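I haven't traced where exactly this happens inside JAX-FEM, but this error pattern usually means PETSc was built with 32-bit indices while NumPy/JAX produced int64 index arrays. A standalone petsc4py sketch of that situation and the usual fix (hypothetical, not JAX-FEM's code):

```python
import numpy as np
from petsc4py import PETSc

# CSR index arrays as NumPy/JAX typically produce them: int64 on 64-bit hosts.
indptr = np.array([0, 1, 2], dtype=np.int64)
indices = np.array([0, 1], dtype=np.int64)
data = np.array([1.0, 2.0])

# Casting the index arrays to PETSc's own integer type avoids the
# "cannot cast ... according to the rule 'safe'" TypeError when PETSc
# was built with 32-bit indices (the default build).
A = PETSc.Mat().createAIJ(
    size=(2, 2),
    csr=(indptr.astype(PETSc.IntType),
         indices.astype(PETSc.IntType),
         data),
)
A.assemble()
```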
Hi,
Ideally, the solvers should come from JAX itself, to have easy access to all of JAX's transformations. However, external libraries like PETSc offer very efficient and diverse solvers, which can be treated as black boxes from JAX's point of view. So we would like to use these solvers while maintaining as much of JAX's capability as possible.
Currently, the solver part is implemented (the `ad_wrapper` function in `solver.py`) by specifying the `custom_vjp` tag, which overrides the default traced autodiff. But this does not prevent JAX from tracing into the function; since the function uses external libraries, this results in errors when JAX-FEM is used inside bigger pipelines. Further, providing only the `vjp` rule prevents access to forward autodiff as well.

As of now, there are two ways to handle this.
Using `custom_jvp` + `pure_callback` would ensure that the entire package is consistent. I suggest we discuss this possibility during our meeting; a rough sketch of the pattern is below.
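For concreteness, here is a minimal, self-contained sketch of that pattern with hypothetical names (this is not JAX-FEM's actual `ad_wrapper`): a black-box linear solve A(theta) x = b is hidden from the tracer with `pure_callback`, and a `custom_jvp` rule restores forward-mode autodiff by solving the tangent system with the same black box.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for an external (e.g. PETSc) solve of A(theta) x = b;
# it runs in plain NumPy, outside of JAX tracing.
def external_solve(theta, b):
    A = np.diag(theta)
    return np.linalg.solve(A, b).astype(b.dtype)

@jax.custom_jvp
def solve(theta, b):
    # pure_callback keeps JAX from tracing into the external library,
    # so the call is safe under jit and inside bigger pipelines.
    out = jax.ShapeDtypeStruct(b.shape, b.dtype)
    return jax.pure_callback(external_solve, out, theta, b)

@solve.defjvp
def solve_jvp(primals, tangents):
    theta, b = primals
    dtheta, db = tangents
    x = solve(theta, b)
    # Implicit differentiation of A(theta) x = b gives A dx = db - dA x.
    # With A = diag(theta), dA x = dtheta * x, and the tangent system is
    # solved by the same black-box solver.
    dx = solve(theta, db - dtheta * x)
    return x, dx

theta = jnp.array([2.0, 4.0])
b = jnp.ones(2)
print(jax.jit(solve)(theta, b))  # jit no longer traces into the external call

# Forward-mode directional derivative with respect to theta now works:
v = jnp.array([1.0, 0.0])
x, dx = jax.jvp(solve, (theta, b), (v, jnp.zeros_like(b)))
print(x, dx)
```

Reverse mode would still need separate handling (a callback is not transposable, so a vjp cannot be derived from this jvp automatically), but jit-safety and forward autodiff, the two things the current `custom_vjp`-only wrapper lacks, both work here.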