jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

NotImplementedError: Differentiation rule for 'custom_lin' not implemented #8557

Open MaxiLechner opened 3 years ago

MaxiLechner commented 3 years ago

I'm working on implicit differentiation using both jaxopt and jax-md, and I've run into a NotImplementedError. I've figured out how to make the error go away, but I don't understand why my fix is necessary in the first place.

Here's a notebook that demonstrates the issue: https://colab.research.google.com/drive/1DttJc6yyXy0KZ_5c321-ksRQpGd0C-Bk?usp=sharing

The issue seems to be caused by the variable 'box_size' depending on the parameter I want to differentiate: when I call lax.stop_gradient on box_size, the error goes away and I get the same results as with plain (non-implicit) autodiff. I'm not sure what the underlying error is, because when I call stop_gradient on the output of the solver function instead of on box_size, I get a tracer leak with the following error message:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape () and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.Detail: Different traces at same level: Traced<ShapedArray(float32[])>with<JVPTrace(level=2/3)>
  with primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=0/3)>
       tangent = Traced<ShapedArray(float32[]):JaxprTrace(level=1/3)>, JVPTrace(level=2/3)

I can try to construct a simpler example if that helps you.

zhangqiaorjc commented 3 years ago

@froystig

FerranAlet commented 3 years ago

I'm running into the same error using just jaxopt, without jax_md (but using haiku).

mblondel commented 2 years ago

The issue, I believe, is that D is defined in the enclosing scope of the solver (the solver closes over it), and JAXopt currently doesn't support closure conversion. Instead, D should be an explicit argument of both your solver and explicit_force_fn (note that when using custom_root(optimality_fun)(solver), optimality_fun (explicit_force_fn in your case) and solver are assumed to have the same signature). If D is an explicit argument of solver, you will be able to call R_final = decorated_solver(None, R_init, D=D), and differentiation w.r.t. D should work.

Copy-pasting the relevant code for context:

```python
import jax.numpy as jnp
import jaxopt
from jax import jit, lax, random
from jax_md import energy, quantity, space

# Assumed dtype alias (the notebook's own definition is not shown);
# true float64 requires jax_enable_x64.
f64 = jnp.float64

# run_minimization_while is assumed to be a helper defined in the linked
# notebook; it is not part of this snippet.

def run_implicit(D, key, N=128):
    box_size = 4.5 * D
    # box_size = lax.stop_gradient(box_size)

    displacement, shift = space.periodic(box_size)

    R_init = random.uniform(key, (N, 2), minval=0.0, maxval=box_size, dtype=f64)

    energy_fn = jit(energy.soft_sphere_pair(displacement, sigma=D))

    force_fn = jit(quantity.force(energy_fn))

    # wrap force_fn with a lax.stop_gradient to prevent a CustomVJPException
    no_grad_force_fn = jit(lambda x: lax.stop_gradient(force_fn(x)))

    # make the dependence on the variables we want to differentiate explicit
    explicit_force_fn = jit(lambda R, p: force_fn(R, sigma=p))

    def solver(params, x):
        # params are unused
        del params
        # need to use no_grad_force_fn!
        return run_minimization_while(no_grad_force_fn, x, shift, dt_start=0.01, dt_max=0.05)[0]
        # return jax.lax.stop_gradient(run_minimization_while(no_grad_force_fn, x, shift, dt_start=0.01, dt_max=0.05)[0])

    decorated_solver = jaxopt.implicit_diff.custom_root(explicit_force_fn)(solver)

    R_final = decorated_solver(None, R_init)

    # Here we can just use our original energy_fn/force_fn
    return energy_fn(R_final)
```
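A minimal sketch of why the differentiated parameter has to be an explicit argument, written in plain JAX rather than JAXopt: a hand-written implicit-differentiation rule (here via jax.custom_vjp) can only return cotangents for the function's own inputs, so d is passed in rather than captured from an enclosing scope. The scalar equation x**3 + x - d = 0 and the names optimality_fn, _newton and solver are hypothetical stand-ins for explicit_force_fn and the FIRE-based solver above.

```python
import jax
import jax.numpy as jnp


def optimality_fn(x, d):
    # Hypothetical stand-in for explicit_force_fn(R, D): the solver's
    # output x*(d) is the root of F(x, d) = x**3 + x - d.
    return x ** 3 + x - d


def _newton(x0, d, num_steps=30):
    # Plain Newton iterations; d is an explicit argument, not a closed-over value.
    x = x0
    for _ in range(num_steps):
        fx = optimality_fn(x, d)
        dfx = jax.grad(optimality_fn, argnums=0)(x, d)
        x = x - fx / dfx
    return x


@jax.custom_vjp
def solver(x0, d):
    return _newton(x0, d)


def solver_fwd(x0, d):
    x = _newton(x0, d)
    # Keep the solution, the initial guess and d for the backward pass.
    return x, (x, x0, d)


def solver_bwd(res, g):
    x, x0, d = res
    # Implicit function theorem at the root: dx*/dd = -(dF/dx)^{-1} dF/dd.
    dF_dx = jax.grad(optimality_fn, argnums=0)(x, d)
    dF_dd = jax.grad(optimality_fn, argnums=1)(x, d)
    # The root does not depend on the initial guess, so x0 gets a zero cotangent.
    return jnp.zeros_like(x0), -g * dF_dd / dF_dx


solver.defvjp(solver_fwd, solver_bwd)

x0, d = jnp.float32(1.0), jnp.float32(10.0)
print(solver(x0, d))                       # ≈ 2.0, since 2**3 + 2 = 10
print(jax.grad(solver, argnums=1)(x0, d))  # ≈ 1/13, i.e. 1 / (3 * 2**2 + 1)
```

If solver instead closed over a traced d and you differentiated w.r.t. it, JAX raises a CustomVJPException, because the rule above only says how to differentiate with respect to explicit inputs.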
mblondel commented 2 years ago

For instance, in this example, we differentiate w.r.t. l2reg. You can see that l2reg is an explicit argument of both ridge_solver and jax.grad(ridge_objective).
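The linked example is not reproduced here, but the idea can be sketched in plain JAX, assuming a ridge objective of the form below (ridge_objective, ridge_solver and implicit_grad_l2reg are illustrative names, not copied from JAXopt). Because l2reg is an explicit argument of both the solver and the optimality function jax.grad(ridge_objective), the implicit function theorem can be applied at the solution to get d params / d l2reg, roughly the computation that implicit differentiation performs. The result can be checked against differentiating through the closed-form solve.

```python
import jax
import jax.numpy as jnp


def ridge_objective(params, l2reg, X, y):
    residuals = X @ params - y
    return 0.5 * jnp.mean(residuals ** 2) + 0.5 * l2reg * jnp.sum(params ** 2)


# Optimality condition F(params, l2reg, X, y) = 0; l2reg is an explicit argument.
optimality_fn = jax.grad(ridge_objective, argnums=0)


def ridge_solver(init_params, l2reg, X, y):
    # Same signature as optimality_fn. Closed-form minimizer:
    # (X^T X + n * l2reg * I) w = X^T y.
    del init_params  # unused by the closed-form solve
    n, p = X.shape
    return jnp.linalg.solve(X.T @ X + n * l2reg * jnp.eye(p), X.T @ y)


def implicit_grad_l2reg(params, l2reg, X, y):
    # Implicit function theorem at the solution:
    # d params / d l2reg = -(dF/dparams)^{-1} dF/dl2reg
    dF_dparams = jax.jacobian(optimality_fn, argnums=0)(params, l2reg, X, y)
    dF_dl2reg = jax.jacobian(optimality_fn, argnums=1)(params, l2reg, X, y)
    return -jnp.linalg.solve(dF_dparams, dF_dl2reg)


key_X, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(key_X, (20, 3))
y = jax.random.normal(key_y, (20,))
l2reg = jnp.float32(0.1)

w = ridge_solver(jnp.zeros(3), l2reg, X, y)
implicit = implicit_grad_l2reg(w, l2reg, X, y)
unrolled = jax.jacobian(ridge_solver, argnums=1)(jnp.zeros(3), l2reg, X, y)
print(jnp.allclose(implicit, unrolled, atol=1e-4))
```

Only the optimality function and the solution are needed for the implicit derivative, which is what makes the approach usable for iterative solvers that are expensive or awkward to differentiate through.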