DifferentiableUniverseInitiative / JaxPM

JAX-powered Cosmological Particle-Mesh N-body Solver
MIT License

Lagrangian Neural ODE for Fast Approximate Nbody #10

Closed: EiffL closed this issue 4 months ago

EiffL commented 2 years ago

Hello hello @FloList @ohahn :-) I wanted to share some results from experiments I ran after our discussions last week; the naive thing I wanted to try gave some pretty interesting results. I'd be happy to hear your thoughts on this. I think this might already be close to enough for a short paper.

I have been playing around with the following ideas:

And so the whole thing becomes very close to a neural ODE with minimal parameterisation.

Here is the pseudo-code:

```python
import jax.numpy as jnp
import jax_cosmo as jc
from jax.experimental.ode import odeint
from jaxpm.painting import cic_paint, cic_read
from jaxpm.kernels import fftk, gradient_kernel, laplace_kernel, longrange_kernel

# mesh_shape, model (the parametric b-spline), pos_i, vel_i, a_i, scales,
# cosmo and params are defined elsewhere (see the notebook).

def neural_nbody_ode(state, a, cosmo, params):
    """
    state is a tuple (positions, velocities)
    """
    pos, vel = state

    # Paint the particles on a mesh
    delta = cic_paint(jnp.zeros(mesh_shape), pos)

    # Wave vectors; must be computed before the kernels below use them
    kvec = fftk(mesh_shape)

    # Compute the gravitational potential in Fourier space
    pot_k = jnp.fft.rfftn(delta) * laplace_kernel(kvec) * longrange_kernel(kvec, r_split=0)

    # Apply the learned correction filter, a function of |k| and of a
    kk = jnp.sqrt(sum((ki / jnp.pi)**2 for ki in kvec)) / 2
    pot_k = pot_k * (1. + model.apply(params, kk, jnp.atleast_1d(a)))

    # Compute gravitational forces from the modified potential
    forces = jnp.stack([cic_read(jnp.fft.irfftn(gradient_kernel(kvec, i) * pot_k), pos)
                        for i in range(3)], axis=-1)
    forces = forces * 1.5 * cosmo.Omega_m

    # Update of positions (drift)
    dpos = 1. / (a**3 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * vel

    # Update of velocities (kick)
    dvel = 1. / (a**2 * jnp.sqrt(jc.background.Esqr(cosmo, a))) * forces

    return dpos, dvel

# Run the simulation, saving snapshots at a_i and at each scale factor in `scales`
result = odeint(neural_nbody_ode, [pos_i, vel_i], jnp.array([a_i] + scales),
                cosmo, params, rtol=1e-5, atol=1e-5)
```
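What makes this a neural ODE is that `odeint` from `jax.experimental.ode` is differentiable with respect to the extra arguments passed after the time grid (here `cosmo` and `params`), with reverse-mode gradients computed by the adjoint method. A toy sketch, where a hypothetical one-parameter decay ODE stands in for the PM solver and `theta` stands in for `params`, shows the gradient flowing back through the integrator:

```python
import jax
import jax.numpy as jnp
from jax.experimental.ode import odeint

def decay(y, t, theta):
    # Toy dynamics dy/dt = -theta * y; theta plays the role of learnable parameters
    return -theta * y

def final_value(theta, y0=1.0):
    # Integrate from t=0 to t=1 and return y(1)
    ts = jnp.array([0.0, 1.0])
    ys = odeint(decay, jnp.array(y0), ts, theta, rtol=1e-6, atol=1e-6)
    return ys[-1]

# Analytic solution is y(1) = exp(-theta), so dy(1)/dtheta = -exp(-theta)
value = final_value(0.5)
grad = jax.grad(final_value)(0.5)
```

The same mechanism lets a loss on `result` above be differentiated with respect to the spline parameters.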

I've tried this out by taking CAMELS IllustrisTNG-DM simulation volumes and using, as the loss function, the distance between particle positions at each snapshot. Full notebook here
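A minimal sketch of such a loss, assuming a mean squared distance averaged over snapshots and particles, with positions in mesh units inside a periodic box so that differences are wrapped to the nearest periodic image. `snapshot_position_loss` is a hypothetical name, and the exact form used in the notebook may differ:

```python
import jax
import jax.numpy as jnp

def snapshot_position_loss(pred_positions, ref_positions, box_size):
    """Mean squared particle position error over all snapshots.

    pred_positions, ref_positions: arrays of shape (n_snapshots, n_particles, 3),
    in the same units as box_size (assumed periodic).
    """
    d = pred_positions - ref_positions
    # Minimum-image convention: wrap each component into [-box_size/2, box_size/2]
    d = d - box_size * jnp.round(d / box_size)
    # Squared distance per particle, averaged over snapshots and particles
    return jnp.mean(jnp.sum(d**2, axis=-1))

# Gradient with respect to the predicted positions
grad_fn = jax.grad(snapshot_position_loss)
```

In training, `pred_positions` would be the position component of the `odeint` output, and the gradient would be taken with respect to `params` rather than the positions themselves.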

It's learning an interesting-looking filter (shown here at different scale factors) [image: learned filter at several scale factors], which essentially scales as a but also shifts a little in k as a function of a. Here I've parameterised this correction kernel with a b-spline, with a neural network predicting the knot positions and values of the spline as a function of a.
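A rough sketch of that parameterisation, under several assumptions not stated in the thread: a tiny hand-written MLP maps a to knot positions and values, the knots are kept sorted in (0, 1] by a softplus and cumulative sum, and the spline is swapped for the simplest case, a degree-1 (piecewise-linear) B-spline evaluated with `jnp.interp`. All names (`N_KNOTS`, `init_params`, `knots_and_values`, `correction_filter`) are hypothetical:

```python
import jax
import jax.numpy as jnp

N_KNOTS = 8  # hypothetical number of spline knots

def init_params(key, hidden=16):
    # Small random weights for a one-hidden-layer MLP: a -> 2 * N_KNOTS outputs
    k1, k2 = jax.random.split(key)
    return {
        "w1": 0.1 * jax.random.normal(k1, (1, hidden)),
        "b1": jnp.zeros(hidden),
        "w2": 0.1 * jax.random.normal(k2, (hidden, 2 * N_KNOTS)),
        "b2": jnp.zeros(2 * N_KNOTS),
    }

def knots_and_values(params, a):
    """Map the scale factor a to increasing knot positions in (0, 1] and knot values."""
    h = jnp.tanh(jnp.atleast_1d(a)[:, None] @ params["w1"] + params["b1"])
    out = (h @ params["w2"] + params["b2"])[0]
    # Positive increments, cumulatively summed and normalised -> sorted knots
    increments = jax.nn.softplus(out[:N_KNOTS])
    knots = jnp.cumsum(increments) / jnp.sum(increments)
    values = out[N_KNOTS:]
    return knots, values

def correction_filter(params, kk, a):
    """Piecewise-linear spline through the predicted knots, evaluated at |k|."""
    knots, values = knots_and_values(params, a)
    return jnp.interp(kk, knots, values)
```

It would slot in where the snippet above calls `model.apply`, as `pot_k * (1. + correction_filter(params, kk, a))`; a higher-order B-spline would give a smoother filter at the cost of a more involved evaluation.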

And it seems to do a good job, at least at recovering the correct power spectrum, on a different initial seed and a different set of cosmological parameters: [image: power spectrum comparison] Blue is IllustrisTNG-DM, yellow is without correction, dashed green is with correction.

What is interesting is that it seems to work OK if I only fit the positions, but not as well if I fit both positions and velocities at the same time. I think this indicates that it's not just refining the estimate of the gravitational forces but also modifying the dynamics of the simulation, and that a filter on the potential alone might not be enough to recover both the correct positions and velocities; another degree of freedom is probably needed for that.
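One way to fit positions and velocities at the same time is a weighted sum of the two error terms. A minimal sketch, assuming mean squared errors, positions in mesh units inside a periodic box, and a hypothetical `vel_weight` hyperparameter (none of these choices come from the thread):

```python
import jax.numpy as jnp

def joint_loss(pred_pos, pred_vel, ref_pos, ref_vel, box_size, vel_weight=1.0):
    """Weighted sum of position and velocity errors.

    All arrays have shape (n_snapshots, n_particles, 3).
    """
    # Position term: wrap differences to the nearest periodic image
    dpos = pred_pos - ref_pos
    dpos = dpos - box_size * jnp.round(dpos / box_size)
    pos_term = jnp.mean(jnp.sum(dpos**2, axis=-1))
    # Velocity term: no periodic wrapping needed
    vel_term = jnp.mean(jnp.sum((pred_vel - ref_vel)**2, axis=-1))
    return pos_term + vel_weight * vel_term
```

Since positions and velocities have different units and magnitudes, `vel_weight` sets how much each term pulls on the single potential filter; if no weight fits both well, that is consistent with a filter on the potential alone lacking the needed degree of freedom.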

EiffL commented 2 years ago

Well... I've had to do some more refining and run additional experiments, and it works much better than I expected!

A single set of parameters, trained on CAMELS (i.e. 25 Mpc/h boxes, which I downsampled to 64^3 particles), seems to work at different resolutions and even for different cosmological parameters: [images: three comparison plots]

GeoMarX commented 2 years ago

Some results from the JAX implementation of SpectralConv3d from https://arxiv.org/abs/2010.08895: [images]
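For reference, a minimal sketch of a 3D spectral convolution (the Fourier layer of arXiv:2010.08895), following the structure of the paper's reference PyTorch `SpectralConv3d`: transform with `rfftn`, multiply the lowest `m1 x m2 x m3` modes in the four retained corners of the spectrum by learned complex weights that mix channels, zero the rest, and transform back. This is not the implementation behind the results above; the batch dimension is omitted and the function names are hypothetical:

```python
import jax
import jax.numpy as jnp

def init_spectral_conv3d(key, in_ch, out_ch, m1, m2, m3):
    """Complex weights for the four retained corners of the rfft spectrum."""
    scale = 1.0 / (in_ch * out_ch)
    keys = jax.random.split(key, 8)
    shape = (in_ch, out_ch, m1, m2, m3)
    return [scale * (jax.random.uniform(keys[2 * i], shape)
                     + 1j * jax.random.uniform(keys[2 * i + 1], shape))
            for i in range(4)]

def spectral_conv3d(weights, x, m1, m2, m3):
    """x: real array (in_ch, n1, n2, n3) -> real array (out_ch, n1, n2, n3)."""
    in_ch, n1, n2, n3 = x.shape
    out_ch = weights[0].shape[1]
    x_ft = jnp.fft.rfftn(x, axes=(1, 2, 3))
    out_ft = jnp.zeros((out_ch, n1, n2, n3 // 2 + 1), dtype=x_ft.dtype)

    # Per-mode linear mixing of input channels into output channels
    def mix(a, w):
        return jnp.einsum("ixyz,ioxyz->oxyz", a, w)

    # Low positive/negative frequencies along the first two axes, low kz only
    out_ft = out_ft.at[:, :m1, :m2, :m3].set(mix(x_ft[:, :m1, :m2, :m3], weights[0]))
    out_ft = out_ft.at[:, -m1:, :m2, :m3].set(mix(x_ft[:, -m1:, :m2, :m3], weights[1]))
    out_ft = out_ft.at[:, :m1, -m2:, :m3].set(mix(x_ft[:, :m1, -m2:, :m3], weights[2]))
    out_ft = out_ft.at[:, -m1:, -m2:, :m3].set(mix(x_ft[:, -m1:, -m2:, :m3], weights[3]))
    return jnp.fft.irfftn(out_ft, s=(n1, n2, n3), axes=(1, 2, 3))
```

Truncating to a fixed number of modes makes the learned weights independent of the grid size, provided the grid has at least `2 * m1` and `2 * m2` cells along the first two axes.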