Closed MaxiLechner closed 2 years ago
Good question. Do you think you can put together an example where `jax.closure_convert` fails with different `cell_list_fn`s? I'd love to see the specific usage that's causing problems so I can try to solve it. I'm a little surprised that JAX functions interact with `cell_list_fn` at all, since it gets stripped from the PyTree.
At first glance, I am a little hesitant to store `cell_list_fn` inside `neighbor_list`, although I am open to the possibility. The reason I am a bit worried is that neighbor lists with different capacities will have different `cell_list_fn`s. For example, consider the following use case:
```python
neighbor_fn = neighbor_list(box_size, cutoff)
small_nbrs = neighbor_fn(R)
large_nbrs = neighbor_fn(R, extra_capacity=12)
# Do some simulating...
small_nbrs = neighbor_fn(R, small_nbrs)
large_nbrs = neighbor_fn(R, large_nbrs)
```
Of course, `neighbor_fn` must use a different `cell_list_fn` for `small_nbrs` and `large_nbrs`. Thus, at the very least, `neighbor_fn` cannot store a single `cell_list_fn`. This isn't such a big deal on its own: we could store a cache of `cell_list_fn`s in `neighbor_fn`, keyed by the capacity of the cell list. This is certainly a solution we can fall back on if there isn't another way to solve your issue. However, it would make memory usage a little more opaque. E.g. if I delete `small_nbrs`, the `cell_list_fn` corresponding to `small_nbrs` will not be deleted, so it also won't be evicted from JAX's function cache. This might not make a big difference, but I could imagine it causing a memory leak for long-running simulations that build neighbor lists of many sizes during their run.
In any case, let's definitely solve the problem you're having one way or another. If that means storing the `cell_list_fn` in `neighbor_fn`, then we can do that, but I'd like to see whether there is another solution first!
Interesting, I haven't thought of that use case.
Here's a minimal example notebook. https://colab.research.google.com/drive/1csmZ2mFHED_HrkhIdhhA833Ohqs4_A6_?usp=sharing
While constructing a MWE, I stumbled upon a second, much bigger problem with `closure_convert`: it cannot handle neighbor lists of different sizes.
I assume you are not familiar with the inner workings of `closure_convert`? In that case I'll try to build a simpler example and open an issue with JAX.
If you have any questions we can also talk in person. Just hit up @cpgoodri in that case.
So I have figured out how to sidestep the issues caused by `closure_convert` by making use of the `dynamic_kwargs` parameter of the `force_fn`. Here's the corresponding notebook: https://colab.research.google.com/drive/1KtkABTZJR7PE5kr7Z-8diJccqVZzrah4?usp=sharing.
There is just one thing to keep in mind in order to avoid leaked tracers: it is necessary to partially apply the parameter being differentiated to the `force_fn` that is passed to the `solver` (which computes the energy minimum).
E.g. this works:

```python
solver = lambda x, params: minimize_nl(neighbor_fn, functools.partial(force_fn, sigma=params), x, shift)
R_final, nbrs = root_solve(force_fn, R_init, sigma, solver)
```

and this doesn't:

```python
solver = lambda x, params: minimize_nl(neighbor_fn, force_fn, x, shift)
R_final, nbrs = root_solve(force_fn, R_init, sigma, solver)
```
I'm using neighbor lists here, but the same issue shows up even when you are not using them. When using neighbor lists you additionally get a memory leak for free ;). I am reasonably sure that the memory leak only shows up when you are also getting the leaked tracer.
It's possible to get rid of those leaked tracers by replacing the return statement of `root_solve` with

```python
return jax.tree_map(jax.lax.stop_gradient, solver(init_xs, params))
```

instead of

```python
return solver(init_xs, params)
```

but this doesn't get rid of the memory leak.
While this trick seems to work, it's quite easy to simply forget about it. Do you have an idea of how one could improve upon this?
One idea I came up with is to add a flag like `only_dynamic_kwargs` to `smap.pair` and co, which would add `lax.stop_gradient` calls to the `static_kwargs` when set to `True`, and would further crash if the function is not called with the corresponding `dynamic_kwargs`. E.g.

```python
energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma, only_dynamic_kwargs=True)
energy_fn(x, sigma=sigma)
```

would work, but

```python
energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma, only_dynamic_kwargs=True)
energy_fn(x)
```

would crash. There are of course more `kwargs` than just `sigma`, so I feel like this isn't super user friendly.
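To make the idea concrete, here is a rough sketch of the proposed behaviour. The name `require_dynamic_kwargs` and the toy energy function are made up for illustration; the real implementation would live in `smap.pair` and would also wrap the static values in `lax.stop_gradient`.

```python
# Hypothetical sketch: refuse to silently fall back on statically captured
# parameters, forcing the caller to pass them as dynamic kwargs.

def require_dynamic_kwargs(fn, required):
    def wrapped(R, **dynamic_kwargs):
        missing = [k for k in required if k not in dynamic_kwargs]
        if missing:
            raise ValueError(f"missing dynamic kwargs: {missing}")
        return fn(R, **dynamic_kwargs)
    return wrapped

# Toy pair energy standing in for energy.soft_sphere_pair(...).
energy_fn = require_dynamic_kwargs(
    lambda R, sigma: sigma * sum(r * r for r in R), ["sigma"])
```

Calling `energy_fn(x, sigma=sigma)` works, while `energy_fn(x)` raises, matching the behaviour sketched above.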
Hey Maxi,
Thanks for looking into this so closely. I'm sorry that I'm so unfamiliar with `closure_convert` that I'm not sure I can be of much assistance. If there is a memory leak, then a reference to the arrays must be retained somewhere, and I can think of a few places where that might be happening. JAX has a fairly large cache of JIT-compiled functions that it populates. If a function is in the JIT cache and contains references to objects via its closure, then those objects will not be deleted. One way to test whether this is happening would be to run a very small system and see whether the program still OOMs, or whether it reaches a steady-state memory usage once the JIT cache is full and functions begin to get evicted.
One question, also: does this still happen if you disable the use of cell lists in the neighbor list function?
One thing that might be very helpful would be a reproducer of the leak that doesn't contain any JAX MD code. This would likely involve making a test case that does the same thing neighbor lists do (i.e. retain a `cell_list`-like function) and then checking whether you can still get a memory leak. If so, it would be great to post it to the JAX issue tracker. Seeing a clean example would definitely help me understand where exactly the friction between JAX and JAX MD is coming from.
I can likely make the repro, but I won't have time until next week at the earliest. If you do have time and feel comfortable doing it, that would be super helpful.
Best, Sam
Hi Sam,
I am writing this right before dinner, so I might come back to what I wrote below on Monday.
I don't think I've made myself very clear: the memory leak and `closure_convert` are somewhat orthogonal. I should probably have closed this issue and opened a new one instead. Let me take a step back and give you some background.
I'm trying to implement implicit differentiation. Say I have a `force_fn` named `f` and a function `solver` which, given initial conditions, computes the argmin `z` of the energy of the system. Then I'd like to compute gradients of some function `g`, i.e. `grad(g)(z)`. This works, but as you know, the memory requirements grow too fast to make this work for anything but tiny systems. A better idea than plain reverse-mode autodiff is to make use of the implicit function theorem, which roughly states that in order to compute derivatives of `g(z)` at the minimum, it does not matter how you actually got there. This makes it possible to introduce a function `root_solve(f, init_params, solver)` that computes the energy minimum in the forward pass and memory-efficient gradients in the backward pass by using the custom derivative rules given by `jax.custom_vjp`; i.e. the backward pass only involves `f`.
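The pattern can be sketched on a toy scalar problem. Everything below is a simplified stand-in (not the actual jax_md `root_solve`): the forward pass calls the solver, while the backward pass applies the implicit function theorem using only `f`.

```python
import jax
import jax.numpy as jnp
from functools import partial

def f(z, p):
    # Toy residual whose root is z = cos(p); stands in for a force function.
    return z - jnp.cos(p)

@partial(jax.custom_vjp, nondiff_argnums=(0, 2))
def root_solve(f, p, solver):
    return solver(p)

def root_solve_fwd(f, p, solver):
    z = solver(p)
    return z, (z, p)

def root_solve_bwd(f, solver, res, g):
    z, p = res
    # Implicit function theorem: dz/dp = -(df/dz)^{-1} df/dp, so the
    # backward pass only ever touches f, never the solver.
    dfdz = jax.grad(f, argnums=0)(z, p)
    dfdp = jax.grad(f, argnums=1)(z, p)
    return (-g * dfdp / dfdz,)

root_solve.defvjp(root_solve_fwd, root_solve_bwd)

solver = lambda p: jnp.cos(p)  # an exact "solver" for the toy residual
grad_fn = jax.grad(lambda p: root_solve(f, p, solver))
```

For this toy residual, `grad_fn(p)` returns `-sin(p)`, the derivative of the implicitly defined root `z = cos(p)`, without differentiating through the solver.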
It is not possible to define custom gradients with respect to implicit arguments, so for the backward pass I need a function `f_hat(R, params)` that takes the parameters I want to differentiate with respect to as an explicit argument. As an example, suppose we are starting from something like

```python
neighbor_fn, energy_fn = energy.soft_sphere_pair(displacement, sigma=sigma)
force_fn = quantity.force(energy_fn)
```

and we want to define custom gradients for my function `root_solve`; then there are two ways forward. Let's further assume I want to compute derivatives with respect to `sigma`.
First option: I pull `sigma` out of the `force_fn` using `closure_convert` within `root_solve`; this converted function is then my `f_hat`. This way `root_solve` does not have to depend on `sigma`:

```python
z = root_solve_implicit(force_fn, R_init, solver)
```

The issue with that is that it simply doesn't work together with neighbor lists. `closure_convert` does some form of type specialization, so it isn't possible to change the size of the `neighbor_list`, and furthermore it can't handle function arguments. It might be possible to generalize `closure_convert`, but I don't think I could do this myself.
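For reference, here is roughly how `jax.closure_convert` is meant to be used, on a toy force function rather than the jax_md one: closed-over values that are being differentiated get hoisted into a list of explicit trailing arguments.

```python
import jax
import jax.numpy as jnp
from functools import partial

def force_fn(sigma, R):
    # Toy force function; stands in for the jax_md one.
    return -sigma * R

def loss(sigma, R):
    closed = partial(force_fn, sigma)  # closes over sigma
    # closure_convert returns a function taking R plus the hoisted
    # closed-over values as explicit trailing arguments.
    converted, hoisted = jax.closure_convert(closed, R)
    return jnp.sum(converted(R, *hoisted))

# d/dsigma of sum(-sigma * R) with R = ones(3) is -3.
g = jax.grad(loss)(2.0, jnp.ones(3))
```

Because the hoisted values are explicit arguments, a `custom_vjp` rule can differentiate with respect to them, which is exactly what this option relies on.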
Second option: I make use of the `dynamic_kwargs`. E.g. I can always call `force_fn(R, sigma=sigma)` with a potentially different `sigma`, so my function `f_hat` is just `f` itself. This way I have to pass the argument I want to differentiate to my `root_solve` function:

```python
z = root_solve_explicit(force_fn, R_init, params, solver)
```
The issue here is that JAX finds a leaked tracer coming from the `static_kwargs` used when defining the `energy_fn`. This can be avoided, which is what I tried to sketch in my last comment.
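The dynamic-kwargs pattern can be illustrated with a toy version (hypothetical names; the real smap-generated functions do essentially this merging of static and dynamic kwargs):

```python
# Toy sketch: parameters bound at construction time act as defaults, and any
# value passed at call time overrides them, so f itself can serve as
# f_hat(R, sigma=...).

def soft_sphere_like(sigma):
    def energy_fn(R, **dynamic_kwargs):
        s = dynamic_kwargs.get("sigma", sigma)  # dynamic value wins
        return s * sum(r * r for r in R)        # toy energy, not the real one
    return energy_fn

energy_fn = soft_sphere_like(1.0)
```

Here `energy_fn(R)` uses the statically bound `sigma`, while `energy_fn(R, sigma=...)` overrides it, which is what makes `f` reusable as `f_hat`.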
If you have any more questions just hit me up.
Best, Maxi
Hi Maxi,
As of b134155650043079e6346742678cad6b40afbf1f (released in v0.1.22) the neighbor list no longer carries around the cell list function. Let me know if you have any more comments or issues!
All the best, Sam
Since I'm no longer using `closure_convert` for the implicit differentiation work, I no longer have an opinion on carrying around the cell list function. I hope I didn't make you do the work for nothing.
Nope not at all, it actually came about naturally from a refactor of the cell list code to make it look more like neighbor lists.
For the implicit differentiation stuff I am working on right now, I would like to use `jax.closure_convert`, as this makes for a much better user interface. One problem I have with this approach is that `jax.closure_convert` can't handle different `cell_list_fn`s, since their hashes do not agree (I'm not completely sure about the cause here). One way to sidestep this whole issue is to simply remove the `cell_list_fn` from `NeighborList` and store it inside `neighbor_list`. What do you think?