jax-md / jax-md

Differentiable, Hardware Accelerated, Molecular Dynamics
Apache License 2.0

Don't pass around cell_list_fn? #134

Closed MaxiLechner closed 2 years ago

MaxiLechner commented 3 years ago

For the implicit differentiation stuff I am working on right now I would like to use jax.closure_convert as this makes for a much better user interface. One problem I have with this approach is that jax.closure_convert can't handle different cell_list_fn as their hashes do not agree (not completely sure about the cause here). One way to sidestep this whole issue is to simply remove the cell_list_fn from NeighborList and store it inside neighbor_list. What do you think?
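For context, closure_convert hoists the arrays a function closes over into an explicit trailing argument list, which is what makes it attractive for the custom-gradient interface. A minimal sketch of the intended usage (toy energy function, not jax-md code):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an energy function that closes over a parameter.
sigma = jnp.arange(1.0, 4.0)
energy = lambda x: 0.5 * jnp.sum(sigma * x ** 2)

x0 = jnp.ones(3)
# closure_convert traces `energy` at x0 and hoists the closed-over
# array into an explicit list of constants.
energy_hat, consts = jax.closure_convert(energy, x0)

# The converted function takes the constants as extra arguments, so a
# custom derivative rule can treat them as ordinary inputs.
assert jnp.allclose(energy_hat(x0, *consts), energy(x0))
```

The catch is that the conversion happens at trace time, so the converted function is specialized to the particular shapes and closed-over functions it saw then.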

sschoenholz commented 3 years ago

Good question. Do you think you can put together an example where jax.closure_convert fails with different cell_list_fn? I'd love to see what the specific usage is that's causing problems to try to solve it. I'm a little surprised that jax functions interact at all with cell_list_fn since it gets stripped from the PyTree.

At first glance, I am a little hesitant to store cell_list_fn inside neighbor_list, although I am open to the possibility. The reason I am a bit worried is that neighbor lists with different capacities will have different cell_list_fns. For example, consider the following use case:

neighbor_fn = neighbor_list(box_size, cutoff)
small_nbrs = neighbor_fn(R)
large_nbrs = neighbor_fn(R, extra_capacity=12)

# Do some simulating...

small_nbrs = neighbor_fn(R, small_nbrs)
large_nbrs = neighbor_fn(R, large_nbrs)

Of course neighbor_fn must use a different cell_list_fn for small_nbrs and large_nbrs. Thus, at the very least, neighbor_fn cannot store a single cell_list_fn. This isn't such a big deal; it is possible that we could store a cache of cell_list_fns in neighbor_fn whose key is the capacity of the cell list. This is certainly a solution we can go to if there isn't another way to solve your issue. However, it will make the memory usage a little more opaque. E.g. if I delete small_nbrs, the cell_list_fn corresponding to small_nbrs will not be deleted and so it also won't be evicted from JAX's function cache. This might not make a big difference, but I could imagine it causing a memory leak for long-running simulations that build neighbor lists of many sizes during their run.
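The capacity-keyed cache idea could look roughly like this (hypothetical sketch; `make_cell_list_fn` is a dummy stand-in for jax-md's internal constructor, and the never-evicted dict is exactly the opacity concern above):

```python
# Hypothetical sketch of caching cell_list_fns by capacity.
# `make_cell_list_fn` is a dummy stand-in for jax-md's internal constructor.
def make_cell_list_fn(capacity):
    def cell_list_fn(positions):
        # Real code would bin `positions` into cells of this fixed capacity.
        return positions, capacity
    return cell_list_fn

_cell_list_cache = {}

def get_cell_list_fn(capacity):
    # One function per capacity, reused so its hash stays stable across
    # calls. Note that entries are never evicted, even after the neighbor
    # list that needed them has been deleted.
    if capacity not in _cell_list_cache:
        _cell_list_cache[capacity] = make_cell_list_fn(capacity)
    return _cell_list_cache[capacity]

assert get_cell_list_fn(8) is get_cell_list_fn(8)
assert get_cell_list_fn(8) is not get_cell_list_fn(20)
```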

In any case, let's definitely solve the problem you're having one way or another. If that means storing the cell_list_fn in the neighbor_fn then we can do that, but I'd like to see if there is another solution first!

MaxiLechner commented 3 years ago

Interesting, I haven't thought of that use case.

Here's a minimal example notebook. https://colab.research.google.com/drive/1csmZ2mFHED_HrkhIdhhA833Ohqs4_A6_?usp=sharing

While constructing a MWE I stumbled upon a second, much bigger problem with closure_convert: it cannot handle neighbor_lists of different sizes.

I assume you are not familiar with the inner workings of closure_convert? In that case I'll try to build a simpler example and open an issue with JAX.

If you have any questions we can also talk in person. Just hit up @cpgoodri in that case.

MaxiLechner commented 3 years ago

So I have figured out how to sidestep the issues caused by closure_convert by making use of the dynamic_kwargs parameter of the force_fn. Here's the corresponding notebook https://colab.research.google.com/drive/1KtkABTZJR7PE5kr7Z-8diJccqVZzrah4?usp=sharing.

There is just one thing to keep in mind when doing this in order to avoid leaked tracers: it is necessary to partially apply the parameter being differentiated to the force_fn that is passed to the solver (which computes the energy minimum).

e.g. this works:

solver = lambda x, params: minimize_nl(neighbor_fn, functools.partial(force_fn, sigma=params), x, shift)
R_final, nbrs = root_solve(force_fn, R_init, sigma, solver)

and this doesn't:

solver = lambda x, params: minimize_nl(neighbor_fn, force_fn, x, shift)
R_final, nbrs = root_solve(force_fn, R_init, sigma, solver)

I'm using neighbor_lists here but the same issue shows up even when you are not using them. When using neighbor_lists you additionally get a memory leak for free ;). I am reasonably sure that the memory leak only shows up when you are also getting this leaked tracer. It's possible to get rid of those leaked tracers if one replaces the return statement of root_solve with

return jax.tree_map(jax.lax.stop_gradient, solver(init_xs, params))

instead of

return solver(init_xs, params)

but this doesn't get rid of the memory leak.

While this trick seems to work it's quite easy to simply forget about it. Do you have an idea about how one could improve upon this?

One idea I came up with is to add a flag like only_dynamic_kwargs to smap.pair and co. which adds lax.stop_gradient calls to the static_kwargs if it is set to True, and which crashes if the function is not called with the corresponding dynamic_kwargs. E.g.

energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma, only_dynamic_kwargs=True)
energy_fn(x, sigma=sigma)

would work but

energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma, only_dynamic_kwargs=True)
energy_fn(x)

would crash. There are of course more kwargs than just sigma, so I feel like this isn't super user friendly.
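As a rough illustration of the proposed behavior, here is a wrapper version (hypothetical; `toy_energy` stands in for what smap.pair would return, and the stop_gradient half of the proposal is reduced to a comment):

```python
import jax.numpy as jnp

# Toy stand-in for the energy function smap.pair would return.
def toy_energy(R, sigma=1.0):
    return jnp.sum((sigma - R) ** 2)

def require_dynamic_kwargs(fn, **static_kwargs):
    # In the real proposal, the static defaults would additionally be
    # wrapped in lax.stop_gradient where they are used.
    def wrapped(R, **dynamic_kwargs):
        missing = set(static_kwargs) - set(dynamic_kwargs)
        if missing:
            raise ValueError(
                f'kwargs {sorted(missing)} must be passed dynamically')
        return fn(R, **dynamic_kwargs)
    return wrapped

energy_fn = require_dynamic_kwargs(toy_energy, sigma=2.0)
energy_fn(jnp.ones(3), sigma=2.0)  # works
# energy_fn(jnp.ones(3))           # would raise ValueError
```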

sschoenholz commented 3 years ago

Hey Maxi,

Thanks for looking into this so closely. I'm sorry that I'm so unfamiliar with closure_convert that I'm not sure I can be of too much assistance. If there is a memory leak, then a reference to arrays must be retained somewhere. I can think of a few places where that might be happening. I know that JAX has a fairly large cache of JIT-compiled functions that it populates. If a function is in the JIT cache and contains references to objects via its closure, then those objects will not be deleted. One way to test whether this is happening would be to run with a very small system and see whether the program still OOMs or whether it reaches a steady-state memory usage once the JIT cache is full and functions begin to get evicted.

One more question: does this still happen if you disable the use of cell lists in the neighbor list function?

One thing that might be very helpful would be a reproducer of the leak that doesn't contain any jax-md code. This would likely involve making a test case that does the same thing that neighbor lists do (i.e., retains a cell_list-like function) and then seeing whether you can still get a memory leak. If so, it would be great to post it to the JAX issue tracker. Seeing a clean example would definitely help me understand where exactly the friction between JAX and JAX MD is coming in.

I can likely make the repro, but I won't have time until next week at the earliest. If you do have time and feel comfortable doing it, that would be super helpful.

Best, Sam

MaxiLechner commented 3 years ago

Hi Sam,

I am writing this right before I'm getting dinner. I might come back to what I wrote below on Monday.

I don't think I've made myself very clear but the memory leak and closure_convert are sort of orthogonal. I should have probably closed this issue and opened up a new one instead. Let me take a step back and give you some background.

I'm trying to implement implicit differentiation. Say I have a force_fn named f and a function solver which, given initial conditions, computes the argmin z of the energy of the system. Then I'd like to compute gradients of some function g, i.e. grad(g)(z). This works, but as you know the memory requirements grow too fast to make it usable for anything but tiny systems. A better idea than plain reverse-mode autodiff is to make use of the implicit function theorem, which roughly states that in order to compute derivatives of g(z) at the minimum it does not matter how you actually got there. This makes it possible to introduce a function root_solve(f, init_params, solver) that computes the energy minimum in the forward pass and memory-efficient gradients in the backward pass using the custom derivative rules given by jax.custom_vjp, i.e. the backward pass only involves f.

It is not possible to define custom gradients with respect to implicitly captured arguments, so for the backward pass I need a function f_hat(R, params) that takes the parameters I want to differentiate with respect to as explicit arguments. As an example, say we are starting from something like this

neighbor_fn, energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma)
force_fn = quantity.force(energy_fn)

and we want to define custom gradients for my function root_solve. Then there are two ways forward. Let's further assume I want to compute derivatives with respect to sigma.

  1. I pull sigma out of the force_fn using closure_convert within root_solve; this converted function is then my f_hat. This way root_solve does not have to depend on sigma.

    z = root_solve_implicit(force_fn, R_init, solver)

    The issue with that is that it simply doesn't work together with neighbor_lists. closure_convert does some form of type specialization, so it isn't possible to change the size of the neighbor_list, and it also can't handle function arguments. It might be possible to generalize closure_convert but I don't think I could do this myself.

  2. I make use of the dynamic_kwargs. E.g. I can always call force_fn(R, sigma=sigma) with a potentially different sigma, so my function f_hat is just f itself. This way I have to pass the argument I want to differentiate to my root_solve function.

    z = root_solve_explicit(force_fn, R_init, params, solver)

    The issue here is that JAX finds a leaked tracer coming from the static_kwargs when defining the energy_fn. This can be avoided, which is what I tried to sketch in my last comment.
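A toy scalar version of option 2 might look like this (hedged sketch, not jax-md's actual implementation; f(z, p) = z**2 - p stands in for the force_fn and an unrolled Newton loop for the solver):

```python
import jax
import jax.numpy as jnp
from functools import partial

# Sketch of implicit differentiation of a root solve via jax.custom_vjp,
# with the parameter p passed explicitly (the root_solve_explicit pattern).
@partial(jax.custom_vjp, nondiff_argnums=(0, 3))
def root_solve(f, z_init, p, solver):
    return solver(f, z_init, p)

def _fwd(f, z_init, p, solver):
    z = solver(f, z_init, p)
    return z, (z, p, z_init)

def _bwd(f, solver, res, z_bar):
    z, p, z_init = res
    # Implicit function theorem: dz/dp = -(df/dz)^{-1} df/dp, so in the
    # scalar case the cotangent on p is -z_bar * (df/dp) / (df/dz).
    dfdz = jax.grad(f, argnums=0)(z, p)
    dfdp = jax.grad(f, argnums=1)(z, p)
    return jnp.zeros_like(z_init), -z_bar * dfdp / dfdz

root_solve.defvjp(_fwd, _bwd)

def newton(f, z, p, steps=20):
    # Plain Newton iteration; gradients never flow through these steps.
    for _ in range(steps):
        z = z - f(z, p) / jax.grad(f)(z, p)
    return z

f = lambda z, p: z ** 2 - p  # root is sqrt(p)
grad_fn = jax.grad(lambda p: root_solve(f, jnp.array(1.0), p, newton))
print(grad_fn(jnp.array(4.0)))  # d sqrt(p)/dp = 1 / (2 sqrt(4)) = 0.25
```

The key point is that the backward pass only ever calls f, never the solver, which is what keeps the memory cost flat regardless of how many solver iterations the forward pass takes.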

If you have any more questions just hit me up.

Best, Maxi

sschoenholz commented 2 years ago

Hi Maxi,

As of b134155650043079e6346742678cad6b40afbf1f (released in v0.1.22) the neighbor list no longer carries around the cell list function. Let me know if you have any more comments or issues!

All the best, Sam

MaxiLechner commented 2 years ago

Since I'm not using closure_convert any more for the implicit differentiation stuff, I no longer have an opinion on carrying around the cell list function. Hope I didn't make you work for nothing.

sschoenholz commented 2 years ago

Nope, not at all; it actually came about naturally from a refactor of the cell list code to make it look more like neighbor lists.