h-vijayakumaran opened this issue 4 months ago
It looks like the following is sufficient to reproduce the issue:
```python
import jax.numpy as jnp

def f(x):
    return jnp.prod(jnp.linalg.eigvalsh(x))  # product of eigenvalues (= determinant)
```
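To run the reproduction end to end, one option is to take the Hessian at a matrix with a repeated eigenvalue. The input `jnp.eye(2)` is an assumption (the comment does not say which input triggers the NaN); a degenerate spectrum is chosen because the eigenvector derivative involves reciprocals of eigenvalue gaps, which presumably blow up there.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Product of the eigenvalues of a symmetric matrix (equal to its determinant).
    return jnp.prod(jnp.linalg.eigvalsh(x))

# Assumed test input: the 2x2 identity, whose eigenvalue 1.0 is repeated.
x = jnp.eye(2)

value = f(x)              # the forward pass is well defined
hess = jax.hessian(f)(x)  # shape (2, 2, 2, 2): second derivative w.r.t. every entry
print(value)
print("NaN in Hessian:", bool(jnp.isnan(hess).any()))
```

Whether the last line reports `True` may depend on the JAX version. Mathematically the Hessian is finite here, since the product of the eigenvalues is the determinant, a polynomial in the matrix entries.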
I traced the NaN to this multiplication, where `Fmat` contains infinities that turn into NaNs during the matmul.
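The failure mode can be sketched in isolation. The expression below is modeled on an eigenvalue-gap matrix of the kind used in the `eigh` derivative rule; treat the exact formula and the eigenvalues `w` as assumptions, not a copy of JAX's source. A repeated eigenvalue gives a zero gap, hence an `inf` entry, and IEEE arithmetic turns `inf * 0` into NaN inside a matrix product.

```python
import jax.numpy as jnp

# Assumed eigenvalues with a repeat (1.0 appears twice), as for a degenerate input.
w = jnp.array([1.0, 1.0, 2.0])
eye = jnp.eye(3)

# Entry (i, j) is 1 / (w[j] - w[i]). Adding eye keeps the diagonal finite (1 / 1),
# and subtracting eye afterwards zeroes it. The repeated pair gives 1 / 0 = inf.
fmat = 1.0 / (eye + w[None, :] - w[:, None]) - eye

# A single inf * 0 term is NaN, and one NaN term poisons the whole dot product.
row = fmat[0:1, :]                      # [[0., inf, 1.]]
col = jnp.array([[1.0], [0.0], [1.0]])
result = row @ col                      # 0*1 + inf*0 + 1*1
print(fmat)
print(result)
```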
@hawkinsp is this working as intended (WAI)?
Description
This is probably related to #1383.
When I compose a function with `eigvalsh` and `sqrt` or `log`, I get `nan` values, mainly with higher-order derivatives such as the Hessian. I have also written a "hardcoded" eigenvalue computation, which seems to give correct results. Here is a minimal example:
System info (python version, jaxlib version, accelerator, etc.)
I am running on CPU.
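JAX ships a helper that prints most of what this section asks for (Python, `jax`, `jaxlib` and NumPy versions, plus the visible devices). The snippet below is a suggestion for collecting that information, not what was run for this report.

```python
import jax

# Prints Python, jax, jaxlib and NumPy versions and the devices JAX can see.
jax.print_environment_info()

# The backend JAX selected by default ("cpu" on a CPU-only install).
print(jax.default_backend())
```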