jobovy / galpy

Galactic Dynamics in python
https://www.galpy.org
BSD 3-Clause "New" or "Revised" License

JAX? #419

Closed. MilesCranmer closed this issue 4 years ago

MilesCranmer commented 4 years ago

Hi @jobovy,

I was wondering if you would consider using JAX (https://github.com/google/jax) instead of numpy for some of the galpy backend? Many numpy calls can be replaced simply by changing import numpy as np to from jax import numpy as np.

All code is then JIT-compiled (via XLA), parallelized, and GPU-ready, without any further changes. In practice it can be faster than handwritten C code, too.
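As a trivial illustration of the drop-in change (nothing galpy-specific, just ordinary numpy-style code):

from jax import numpy as np   # instead of `import numpy as np`
from jax import jit

@jit
def sqdist(x, y):
    # ordinary numpy-style code; compiled the first time it is called
    return np.sum((x - y)**2)

print(sqdist(np.arange(3.0), np.ones(3)))   # runs on the GPU automatically if one is available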

The biggest unique advantage JAX has over numpy or, e.g., numba, is the vmap operation. vmap is functionally similar to multiprocessing.Pool(...).map in that you apply some function over a batch of inputs. The key difference is that vmap traces through your code and adds an extra batch axis to all of your array operations, so the whole batch runs as a single vectorized computation. We used it extensively to build our Lagrangian Neural Networks (https://twitter.com/MilesCranmer/status/1237788581478969350?s=20) to efficiently batch an operation over an array. I think it would be really nice to be able to call vmap on a galpy operation, since one could then efficiently parallelize, e.g., many evaluations in the same potential, and on a GPU.
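As a toy illustration (a made-up scalar function, not galpy code):

from jax import jit, vmap
from jax import numpy as np

def pot(x):
    # written as if x were a single 3-vector
    return 0.5 * np.log(x[0]**2 + x[1]**2 + x[2]**2)

# vmap adds a leading batch axis, so the same code handles many positions at once
batched_pot = jit(vmap(pot))
values = batched_pot(np.ones((1000, 3)))   # shape (1000,)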

Another huge advantage is obviously the grad operation, which lets you differentiate any function, even through an integration. That would let you do optimization by gradient descent anywhere in galpy.
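For example (again a toy problem, a 1D harmonic oscillator rather than anything from galpy):

from jax import grad
from jax import numpy as np
from jax.experimental.ode import odeint

def rhs(y, t):
    x, v = y
    return np.array([v, -x])   # simple harmonic oscillator

def final_position(v0):
    ys = odeint(rhs, np.array([1.0, v0]), np.linspace(0.0, 1.0, 10))
    return ys[-1, 0]

# derivative of the final position with respect to the initial velocity,
# obtained by differentiating straight through the integrator
print(grad(final_position)(0.5))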

Cheers! Miles

MilesCranmer commented 4 years ago

Here's an example I made for a Logarithmic Halo Potential.

The integration over 1e8 timesteps for 5000 different initial conditions takes only 10 seconds in total, and it would automatically be run on the GPU if one is available.

from jax import numpy as jnp
from jax import jit, grad, vmap, random
from jax.experimental.ode import odeint

# Set up integration:
@jit
def phi(x):
    amp = 1
    core = 1e-8
    q = 0.5
    R = x[0]
    phi = x[1]
    z = x[2]
    return 0.5 * jnp.log(R**2 + (z/q)**2 + core**2)

# The gradient of the potential, i.e., minus the acceleration:
nacceleration = jit(grad(phi))

@jit
def odefunc(y, t):
    x = y[:3]
    xt = y[3:]
    xtt = -nacceleration(x)

    return jnp.concatenate([xt, xtt])

_odefunc = jit(vmap(odefunc, (0, None), 0))

@jit
def vodefunc(yall, t):
    yall = yall.reshape(-1, 6)
    out = _odefunc(yall, t)
    return out.reshape(-1)

@jit
def integrate_all(y0s, ts):
    y0s = y0s.reshape(-1)
    yall = odeint(vodefunc, y0s, ts, mxstep=500)

    return yall.reshape(yall.shape[0], -1, 6)

# Set up initial conditions:
rng = random.PRNGKey(0)
n = 5000
y0 = jnp.stack([
    random.uniform(rng, (n,), minval=0.1, maxval=5),
    jnp.zeros(n),
    jnp.zeros(n),
    random.normal(rng+1, (n,)),
    jnp.zeros(n),
    random.normal(rng+2, (n,))]).T

ts = jnp.linspace(0, 10000, int(1e8))
yall = integrate_all(y0, ts)

jobovy commented 4 years ago

Hi @MilesCranmer, thanks for the suggestion and for the example.

I've been interested in jax for a while, for some of the reasons you mention, and have been thinking about implementing some jax-based features, but I must say that the main drawback is that I've found jax extremely difficult to install and run. I have wanted to play around with jax, but have been unable to get it to run on my laptop, and only managed to get it installed on my work servers by having a student compile a wheel on a virtual machine (in both cases the issue, I think, is that the OS is a few years old, but I never have this issue with any other software). Because galpy has a very wide user base, keeping the code easy to install is a priority, so that makes me shy away from including jax (certainly from a blanket change of numpy --> jax.numpy).

As another example of this, I can't actually run your example: I have jax installed on a machine with a GPU, but the memory allocation fails, and I also can't figure out how to run it on the CPU (I tried replacing the jit with @partial(jit, backend='cpu'), but that leads to a segfault). [Actually, after reducing the number of steps to 1e4 I just got it to run on my Titan Xp GPU, but it's slower than the regular galpy implementation; there is probably something wrong with my jax installation.]

So while I'm interested in the kinds of things that jax makes possible, I'm not inclined to include it in the main galpy code right now. However, it might be interesting to do something like define a potential class that uses jax to compute the accelerations and other derivatives and that could then be used with the jax ODE integrator (one could also have a potential class that would, say, interpolate other potentials with jax functions [not sure whether those exist] and thus make all potentials usable).
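Roughly, I'm imagining something like this (just a sketch to illustrate the idea; the _evaluate/_Rforce/_zforce hooks follow the usual pattern for custom galpy potentials, but the class name is made up and I haven't tested any of it):

from jax import grad
from jax import numpy as jnp
from galpy.potential import Potential

class JaxLogarithmicHaloPotential(Potential):   # hypothetical, untested
    def __init__(self, amp=1., core=0., q=1.):
        Potential.__init__(self, amp=amp)
        self._core2 = core**2
        self._q = q
        # only the potential itself is written by hand; jax builds the derivatives
        self._dR = grad(self._pot, argnums=0)
        self._dz = grad(self._pot, argnums=1)

    def _pot(self, R, z):
        return 0.5 * jnp.log(R**2 + (z / self._q)**2 + self._core2)

    def _evaluate(self, R, z, phi=0., t=0.):
        return self._pot(R, z)

    def _Rforce(self, R, z, phi=0., t=0.):
        return -self._dR(R, z)

    def _zforce(self, R, z, phi=0., t=0.):
        return -self._dz(R, z)

One would still need to check that galpy is happy getting jax arrays back from the force methods, but that's the rough shape of what I have in mind.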

A while ago (when jax was still autograd), I played around a bit with methods that replace the numpy currently loaded in a module with another numpy implementation (I was using autograd's, but maybe this could work for jax's as well). I wonder whether something like that would be possible, to add jax in a more hacky way for power users? I think you would have to very uniformly import numpy as, say, numpy everywhere to make this easy, but that wouldn't be a big change (I already do import numpy almost everywhere).
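Something along these lines is the kind of hack I mean (just a sketch; it assumes every galpy module does a plain import numpy and that nothing has bound individual numpy functions at import time):

import sys

import galpy.potential  # import whatever parts of galpy you want to patch
from jax import numpy as jnp

# Rebind the module-level `numpy` name in every loaded galpy submodule to jax.numpy
for name, module in list(sys.modules.items()):
    if name.startswith('galpy') and module is not None and hasattr(module, 'numpy'):
        module.numpy = jnp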

I'm wondering what your imagined use case is? Users aren't generally limited by speed when evaluating potentials in Python (I think), and orbit integration is sped up using C (and is quite efficient, I think, although I will happily believe that jax's JIT could be faster for some potentials). Currently the C implementations of the potentials aren't exposed to the user, but if there is a need, that could be done to allow for faster evaluation (incidentally, just a few days ago I changed the C extension so that it compiles to a libgalpy shared library that can easily be linked to by other C code, as part of an ongoing project to make the C code more portable).

MilesCranmer commented 4 years ago

Hi @jobovy,

Thanks for the quick response!

Regarding the example, I think I made a mistake in adding the jit to the integrate_all function: it probably tries to unroll the integrator loop, hence the memory issue. Also, with the JIT it can take a few runs before a call speeds up, and on the GPU you might not see a speedup until you run >100,000 integrations in parallel. But maybe there is something else going on, I'm not sure.

However, it might be interesting to do something like define a potential class that uses jax to compute the accelerations and other derivatives and that could then be used with the jax ODE integrator (one could then also have a potential class that would, say, interpolate other potentials with jax functions [not sure whether these exist] and thus make all potentials useable).

I think this is a great idea and probably the best bet for a practical implementation! Indeed, I've also had some trouble installing JAX and had to install it from source on CentOS to get it to work. My hope is that when FLAX (Google's new deep learning framework on top of JAX) gets released, more effort will have gone into the pip and conda wheels and it will be more practical to include JAX in other packages.

I'm wondering what your imagined use case is?

My personal interest is in having full gradient information through integrations for ML gradient-descent applications. More generally, I could see the GPU side of things being useful for people running many different integrations in parallel. What do you think? Interested in hearing your thoughts.

One idea could be to have a matplotlib-like mechanism for changing backends:

import galpy
galpy.use('jax')
from galpy import ...

and then each module could check that flag to decide which package to import.
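Completely hypothetically (galpy doesn't have a use function; this is just to show the shape of the idea), the backend module could look something like:

# galpy/_backend.py (hypothetical module)
_BACKEND = 'numpy'

def use(backend):
    # must be called before the galpy submodules are imported
    global _BACKEND
    if backend not in ('numpy', 'jax'):
        raise ValueError('unknown backend: %s' % backend)
    _BACKEND = backend

def get_numpy():
    # return the numpy-compatible module for the selected backend
    if _BACKEND == 'jax':
        from jax import numpy
    else:
        import numpy
    return numpy

# and at the top of each galpy module, instead of `import numpy`:
# from galpy._backend import get_numpy
# numpy = get_numpy()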

Cheers, Miles

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. If the issue has been resolved since the last activity, please close the issue. Thank you for your contributions.