MaxiLechner opened this issue 3 years ago
@froystig I'm running into the same error using just jaxopt, without jax_md (but using haiku).
The issue, I believe, is that `D` is defined in the scope of the solver, and JAXopt currently doesn't support closure conversion. Instead, `D` should be an explicit argument of your `solver` and of `explicit_force_fn` (note that when using `custom_root(optimality_fun)(solver)`, `optimality_fun` (`explicit_force_fn` in your case) and `solver` are assumed to have the same signature). If `D` is an explicit argument of `solver`, then you will be able to call `R_final = decorated_solver(None, R_init, D=D)` and differentiation w.r.t. `D` should work (a toy sketch of this pattern follows the code below).
Copy-pasting the relevant code for context:
```python
import jax
import jax.numpy as jnp
import jaxopt
from jax import jit, lax, random
from jax_md import energy, quantity, space

jax.config.update("jax_enable_x64", True)
f64 = jnp.float64

# run_minimization_while is a minimization loop defined in the linked notebook (not shown here).

def run_implicit(D, key, N=128):
    box_size = 4.5 * D
    # box_size = lax.stop_gradient(box_size)
    displacement, shift = space.periodic(box_size)
    R_init = random.uniform(key, (N, 2), minval=0.0, maxval=box_size, dtype=f64)
    energy_fn = jit(energy.soft_sphere_pair(displacement, sigma=D))
    force_fn = jit(quantity.force(energy_fn))
    # wrap force_fn with a lax.stop_gradient to prevent a CustomVJPException
    no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))
    # make the dependence on the variables we want to differentiate explicit
    explicit_force_fn = jit(lambda R, p: force_fn(R, sigma=p))

    def solver(params, x):
        # params are unused
        del params
        # need to use no_grad_force_fn!
        return run_minimization_while(no_grad_force_fn, x, shift, dt_start=0.01, dt_max=0.05)[0]
        # return jax.lax.stop_gradient(run_minimization_while(no_grad_force_fn, x, shift, dt_start=0.01, dt_max=0.05)[0])

    decorated_solver = jaxopt.implicit_diff.custom_root(explicit_force_fn)(solver)
    R_final = decorated_solver(None, R_init)
    # Here we can just use our original energy_fn/force_fn
    return energy_fn(R_final)
```
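To make the signature convention concrete, here is a minimal, self-contained toy sketch (not your jax-md code; the cube-root problem and the name `theta`, standing in for `D`, are made up for illustration): the parameter we differentiate with respect to is an explicit argument of both the optimality function and the solver, and the two share the same signature apart from the first slot (solution vs. initialization).

```python
import jax
import jaxopt

# Optimality condition F(x, theta) = 0 implicitly defines x*(theta) = theta ** (1/3).
def optimality_fun(x, theta):
    return x ** 3 - theta

# Same signature as optimality_fun, except the first argument is the
# initialization (unused here) instead of the solution.
@jaxopt.implicit_diff.custom_root(optimality_fun)
def solver(init_x, theta):
    del init_x  # closed-form "solve", no initialization needed
    return theta ** (1.0 / 3.0)

# theta is an explicit argument, not a closed-over value, so implicit
# differentiation w.r.t. it works:
print(jax.grad(lambda theta: solver(None, theta))(8.0))  # ~1/12, i.e. d(theta**(1/3))/dtheta at theta = 8
```

Applied to the code above, the analogous change is to thread `D` through `solver` (and `explicit_force_fn`) as an argument rather than capturing it from the enclosing `run_implicit` scope.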
I'm working on implicit differentiation using both jaxopt and jax-md and I've run into a `NotImplementedError`. I've figured out how to make the error go away, but I don't understand why my fix is necessary in the first place. Here's a notebook that demonstrates the issue: https://colab.research.google.com/drive/1DttJc6yyXy0KZ_5c321-ksRQpGd0C-Bk?usp=sharing

The issue seems to be caused by the fact that the variable `box_size` depends on the parameter that I want to differentiate: when I call `lax.stop_gradient` on `box_size`, the error goes away and I get the same results as with plain autodiff. I'm not quite sure what the underlying error is, because when I don't call `stop_gradient` on `box_size` but instead on the output of the `solver` function, I get a tracer leak instead, with the following error message:

I can try to construct a simpler example if that helps you.
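For what it's worth, here is an attempt at such a simpler example (a toy, not the jax-md setup; `run`, `scale`, and `theta` are made-up names, with `scale` playing the role of `box_size = 4.5 * D`). It matches the diagnosis above: the `custom_root`-decorated solver is a `jax.custom_vjp` function under the hood and closes over a value that depends on the parameter being differentiated, which JAX's custom-VJP machinery cannot differentiate through. In my understanding this should fail with an error about differentiating a `custom_vjp` function with respect to a closed-over value (the exact exception may differ from the `NotImplementedError` seen in the jax-md case), while uncommenting the `stop_gradient` line mirrors the `box_size` workaround.

```python
import jax
from jax import lax
import jaxopt

def run(theta):
    scale = 4.5 * theta  # analogous to box_size = 4.5 * D
    # scale = lax.stop_gradient(scale)  # uncommenting this mirrors the box_size workaround

    def optimality_fun(x):
        return x - scale  # closes over the theta-dependent value

    @jaxopt.implicit_diff.custom_root(optimality_fun)
    def solver(init_x):
        del init_x
        return scale  # the solver also closes over it

    return solver(None)

jax.grad(run)(2.0)  # expected to fail: the decorated solver closes over a traced value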