jax-ml / jax


Preserving memory between compiled function calls #7550

Open DiffeoInvariant opened 3 years ago

DiffeoInvariant commented 3 years ago

I have a JAX function whose Jacobian I need to compute; it's a function from R^(m×n) to R^(m×n), so I'm using (JIT-compiled) jax.jacfwd. However, computing this Jacobian requires several hundred MB of intermediate allocations, which noticeably slows down the computation. The function is called several hundred to several thousand times per program run with the same argument shapes and dtypes, so this repeated allocation and deallocation ends up costing a lot of compute time.

Would it be possible to implement an argument to jax.jit that would cause the JIT-compiled function to retain the memory it allocates for intermediate objects between calls? I'd be happy to help write this code if it's something y'all think is possible and reasonable (and you want the help); it would make a big difference in my code's total runtime.

Thanks!

Example of what I want to be able to do:

```python
def f(*args):
    ...  # some jax computation

args = ...  # get the jnp.ndarray instances to linearize around
jac = jax.jit(jax.jacfwd(f), preserve_memory=True)  # proposed flag
for _ in range(10000):
    j = jac(*args)  # only allocates memory on the first iteration
    ...  # do something with j; get new args with the same shapes and dtypes
```
mattjj commented 3 years ago

Thanks for the question!

Do you mean that you want to save the time spent allocating/deallocating memory via system calls, i.e. not the time spent on compute and reading/writing to that allocated memory? (If so, I'm surprised that would take a lot of time!)

Is it feasible to write the loop not as a Python loop but as a jax.lax.fori_loop? The body of your loop would have to satisfy some requirements (basically, being jittable). (If not, can you say why not, in case we can think up a workaround?)
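For concreteness, here's a minimal sketch of the kind of restructuring I mean, with a toy `f` and a toy update standing in for your real computation (none of these names come from your code):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.tanh(x)  # toy stand-in for the real computation

jac = jax.jacfwd(f)

def body(i, x):
    j = jac(x)                    # Jacobian computed inside the traced loop
    return x - 0.1 * jnp.diag(j)  # toy update standing in for "get new args"

# The whole loop is traced and compiled once, so XLA can plan and reuse
# buffers across iterations instead of re-launching from Python each time.
x_final = jax.lax.fori_loop(0, 10000, body, jnp.ones(8))
```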

DiffeoInvariant commented 3 years ago

Hi Matt, thanks for the response! I am indeed talking about saving the time spent in system calls, e.g. by using a pool allocator that doesn't free its allocated memory between jitted function calls. It's hard to tell exactly how much time the system calls account for, but given how much the function allocates and how often it is called, it is likely a nontrivial fraction of my program's total runtime (getting the Jacobian is part of a nonlinear solve in my forward model, so it's called multiple times each iteration inside an optimization loop).

I could come up with a concrete estimate of the time spent in system calls doing allocation and deallocation if you're interested in specific timing results, but since that would require writing a parser for the strace output, I'd only do it if that information is useful in solving the problem (happy to do it if it is, though).

The reason it needs such big allocations is that it's computing the Jacobian of a function from R^(m×n) to R^(m×n) with potentially very large m (m being the number of quadrature points on a mesh we solve a PDE on, and n the number of dofs at each point, usually at least 9).

EDIT: I did some simple timing, and it looks like while the system calls are probably part of the problem (especially if the allocator is calling calloc instead of malloc), they're not nearly as big a problem as all the reading and writing. At the bottom of this post I've attached some timing results from my laptop (using C++) and the code used to generate them, where "time to allocate" is the time for std::malloc() plus zeroing the resulting buffer (treating it as an int *), and "time to free" is just the time spent in std::free(). FWIW, the JAX program I'm talking about allocates and frees at least about 500 MB on each call to the compiled Jacobian function, and if this code is even remotely similar to what's going on inside JAX for memory allocation (I can verify that the program is indeed allocating and deallocating memory by watching htop), this could account for almost half the time JAX takes for the entire Jacobian computation.

Unfortunately, the lax loop primitives don't help much here, mainly because the code that runs in that big loop and calls the Jacobian-computing function is mostly not JAX code but external PDE solver code that calls JAX to evaluate a material model, so I can't trace/JIT the loop itself.

Now, if it turns out that applying (some of) the intermediate operators matrix-free, without allocating so much intermediate memory, is a possibility, I'd be very interested in that, as it would bypass this problem entirely. I'm happy to share the HLO for one of these large Jacobian-computing functions (or anything else, for that matter) if that would help. Basically, we have a nonlinear solver (PETSc's SNES, if you're curious) that solves a PDE defined with FEniCS inside an optimization loop, except we call a JAX function to evaluate part of that PDE. Again, I'm happy to share the actual code with you if that helps (it's all open-source anyway, hosted at https://www.gitlab.com/crikit/crikit).

Timing results (compiled with clang at -O2):

| Size (MB) | malloc + zero (µs) | free (µs) |
|----------:|-------------------:|----------:|
| 1         | 595                | 53        |
| 10        | 5993               | 929       |
| 100       | 46497              | 6376      |
| 200       | 87819              | 12504     |
| 300       | 140508             | 18951     |
| 400       | 178527             | 24942     |
| 500       | 236468             | 30494     |
| 600       | 275703             | 36235     |
| 700       | 313910             | 39993     |
| 800       | 353498             | 45336     |
| 900       | 394349             | 49991     |
| 1000      | 432804             | 52392     |

code: malloc_time.cc

mattjj commented 3 years ago

> Basically, we have a nonlinear solver (PETSc's SNES, if you're curious)

Ah, I am more familiar with Nintendo's SNES.

Thanks for explaining! Correct me if I'm wrong, but it sounds like the fundamental issue here is that we want to speed up this JAX code as much as possible, which could include (1) speeding up the compiled HLO representing a single execution as well as (2) speeding up its repeated execution.

If that indeed is the fundamental issue, I think we have to start by profiling to see what might be improvable. That may provide better gains than acting on a particular hypothesis too early.

Here are some instructions on how to profile JAX code. For understanding (1) it may be useful to look at the XLA HLO profile, and for understanding (2) it may be most useful to look at the trace viewer.
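For example, a minimal way to capture a trace for the trace viewer might look like the following sketch (the log directory path is arbitrary, and the toy function just stands in for your Jacobian computation):

```python
import jax
import jax.numpy as jnp

jac = jax.jit(jax.jacfwd(jnp.tanh))
x = jnp.ones(100)
jac(x).block_until_ready()  # warm up so compilation isn't in the trace

with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(10):
        jac(x).block_until_ready()

# Then inspect with: tensorboard --logdir /tmp/jax-trace
```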

WDYT?

> Now, if it turns out that applying (some of) the intermediate operators matrix-free and not allocating so much intermediate memory is a possibility, I'd be very interested in that, as it would bypass this problem entirely.

Bypassing the problem entirely sounds like a good thing! But we may have to iterate a bit more for me to understand what you mean here.

In a way, all of JAX's autodiff works matrix-free: jvp applies a matrix-free Jacobian-vector product to a vector by applying a sequence of primitives' Jacobian-vector products to the appropriate intermediate vectors. In turn, jacfwd works by doing the same operations batched over an entire standard basis of input vectors. The memory cost of evaluating jvp(f, (x,), (v,)) should be about twice the memory cost of evaluating f(x), and evaluating jacfwd(f)(x) should be about (1 + m*n) times the memory cost of evaluating f.
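Here's a toy illustration of that relationship, showing that jacfwd agrees with jvp vmapped over the standard basis (the `f` here is just a stand-in):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x  # stand-in function from R^n to R^n

x = jnp.arange(1.0, 7.0)

# One matrix-free Jacobian-vector product:
y, jv = jax.jvp(f, (x,), (jnp.ones_like(x),))

# jacfwd is (morally) that same jvp batched over the standard basis:
_, cols = jax.vmap(lambda v: jax.jvp(f, (x,), (v,)))(jnp.eye(x.size))
assert jnp.allclose(cols.T, jax.jacfwd(f)(x))  # cols[j] is column j of J
```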

If you want the full Jacobian matrix and can tolerate the memory cost of applying jacfwd(f), then I would guess the only issue here is performance optimization, i.e. we should see if we can minimize the time it takes to execute the XLA:CPU compiled executable, and also the time it takes to launch such executables from the JAX runtime. The first step in that process would be profiling as described above.

Alternatively, if there's some way to bypass the problem, e.g. by using another algorithm which uses jvp rather than jacfwd (e.g. perhaps you only need to estimate some function of this Jacobian, and you have a good polynomial approximation to the function...), then it could be interesting to think through those options first.
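For instance, if something like the trace of the Jacobian were enough, a Hutchinson-style estimator needs only a handful of jvp calls and never materializes the matrix. This is purely illustrative, since I don't know what quantity you actually need:

```python
import jax
import jax.numpy as jnp

def trace_estimate(f, x, key, num_samples=16):
    # Hutchinson estimator: E[v . (J v)] = tr(J) when E[v v^T] = I,
    # computed with matrix-free jvps instead of the full Jacobian.
    def one_sample(k):
        v = jax.random.normal(k, x.shape, x.dtype)
        _, jv = jax.jvp(f, (x,), (v,))
        return jnp.vdot(v, jv)
    return jnp.mean(jax.vmap(one_sample)(jax.random.split(key, num_samples)))
```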

WDYT?