f0uriest opened 2 weeks ago
I think we can get rid of the error simply by changing the dtype of the iteration counter to float, like this:
```python
...
state = (
    jnp.atleast_1d(jnp.asarray(guess)),  # x
    jnp.atleast_1d(resfun(guess)),  # residual
    0.0,  # number of iterations
)
...
```
As long as we don't use the derivative of the number of iterations later in the code, I believe this shouldn't change the differentiation of `root`.
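A minimal sketch of the float-counter workaround, assuming a scalar Newton solver driven by `lax.while_loop` and wrapped in `lax.custom_root` with `has_aux=True`. The names `newton_solve` and `find_root`, and the test problem `x**2 - a` (whose root is `sqrt(a)`), are hypothetical stand-ins rather than the original poster's code.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0, tol=1e-6, maxiter=50):
    """Hypothetical scalar Newton solver; returns (root, (residual, niter))."""
    df = jax.grad(f)

    def cond(state):
        x, fx, niter = state
        return (jnp.abs(fx) > tol) & (niter < maxiter)

    def body(state):
        x, fx, niter = state
        x = x - fx / df(x)
        return x, f(x), niter + 1.0

    # The iteration counter starts as a float (0.0) instead of an int (0),
    # so the zero tangent attached to this aux output has an ordinary float dtype.
    state = (x0, f(x0), jnp.float32(0.0))
    x, fx, niter = lax.while_loop(cond, body, state)
    return x, (fx, niter)


def find_root(a, x0=1.0):
    f = lambda x: x**2 - a  # hypothetical test problem, root = sqrt(a)
    tangent_solve = lambda g, y: y / g(1.0)  # scalar linear solve
    return lax.custom_root(f, jnp.float32(x0), newton_solve, tangent_solve,
                           has_aux=True)


# d sqrt(a) / da = 1 / (2 sqrt(a)); the aux output is passed through unchanged.
droot, (fx, niter) = jax.jacfwd(find_root, has_aux=True)(2.0)
```

The derivative of the root comes from `custom_root`'s implicit-function rule; the float `niter` only rides along as aux and is never differentiated.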
That said, this is probably not how you want to implement it. A more proper way could be to write a `custom_jvp` for `root` and set the derivative of `niter` to a `SymbolicZero`, but this is more cumbersome.
Description
I have a basic root finder like this:
which returns the root, the value of `f` at the root, and the number of steps taken. Previously this worked fine with `has_aux=True` for `custom_root`. However, v0.4.34 seems to have changed something in the way tangents of non-differentiable values get propagated (#24262). Now running the following
gives the following:
I get the same error if I drop the aux output in `find_root_fun` and leave out the `has_aux` when calling `jacfwd`. The only way I've found to avoid the error is to remove the aux from the innermost `solve` and set `has_aux=False` on `custom_root`.
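For comparison, a sketch of that workaround using the same hypothetical test problem `x**2 - a`: the solver returns only the root, `custom_root` is called without `has_aux`, and the step count is simply discarded. `newton_solve` and `find_root` are illustrative names, not the poster's code.

```python
import jax
import jax.numpy as jnp
from jax import lax


def newton_solve(f, x0, tol=1e-6, maxiter=50):
    """Hypothetical Newton solver that returns only the root (no aux)."""
    df = jax.grad(f)

    def cond(state):
        x, niter = state
        return (jnp.abs(f(x)) > tol) & (niter < maxiter)

    def body(state):
        x, niter = state
        return x - f(x) / df(x), niter + 1

    x, _ = lax.while_loop(cond, body, (x0, 0))  # the step count is dropped here
    return x


def find_root(a):
    f = lambda x: x**2 - a  # hypothetical test problem, root = sqrt(a)
    # has_aux defaults to False, so solve must return only the solution.
    return lax.custom_root(f, jnp.float32(1.0), newton_solve,
                           lambda g, y: y / g(1.0))


droot = jax.jacfwd(find_root)(2.0)
```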
Is this expected? I assumed having integer-valued aux output was kind of the point of the `has_aux` option?

System info (python version, jaxlib version, accelerator, etc.)
```
jax:    0.4.34
jaxlib: 0.4.34
numpy:  1.24.4
python: 3.10.11 (main, May 16 2023, 00:28:57) [GCC 11.2.0]
jax.devices (8 total, 8 local): [CpuDevice(id=0) CpuDevice(id=1) ... CpuDevice(id=6) CpuDevice(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='Discovery', release='5.15.0-119-generic', version='#129~20.04.1-Ubuntu SMP Wed Aug 7 13:07:13 UTC 2024', machine='x86_64')
```