google / jax-cfd

Computational Fluid Dynamics in JAX
Apache License 2.0
740 stars, 105 forks

JaxNumPy functions for GridArrays/GridVariables #110

Open · gemmaellen opened this issue 2 years ago

gemmaellen commented 2 years ago

Hi! This is a great project, and I'm a big fan of both the machine learning applications here and some of the smaller helper structures, in particular base.grids.

Currently, it is possible to add two GridArrays, but not two GridVariables. For example, this works fine:

import jax_cfd.base.grids as gd
import jax.numpy as jnp

grid = gd.Grid([4,], domain = [(0, 1),])

array_of_values = jnp.array([2.0, 2.0, 3.0, 4.0])

centered_array = grid.center(array_of_values)

print(centered_array + centered_array)

But this throws an exception:

bc = gd.BoundaryConditions((gd.PERIODIC,))

centered_variable = gd.GridVariable(centered_array, bc)

print(centered_variable + centered_variable)

I'm happy to have a go at implementing this myself, if someone isn't already working on it.

Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy? For example, this throws an exception:

print(jnp.abs(centered_array))

But this works:

import numpy as np
print(np.abs(centered_array))

I assume it's implemented this way because NumPy provides a hook (the __array_ufunc__ protocol, plus the NDArrayOperatorsMixin helper) that we can use to funnel calls to the appropriate JaxNumPy function, whereas JaxNumPy has no such hook.

shoyer commented 2 years ago

Hi @gemmaellen -- thanks for your interest!

This was an intentional design choice -- GridVariables have boundary conditions, which we don't know how to propagate automatically (unless using periodic boundaries, which aren't really boundary conditions at all). So we only support math on GridArray objects.
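A minimal sketch of this design choice, using hypothetical stand-in classes rather than jax_cfd's real ones (ToyGridArray, ToyGridVariable and the helper add_with_bc are illustrative names only): arithmetic is defined on the array type, the boundary-condition wrapper deliberately has no +, and the caller does the math on the arrays and attaches whatever BC they know is correct for the result.

```python
import dataclasses

import numpy as np


# Hypothetical stand-ins, reduced to the fields needed for the illustration.
# These are NOT jax_cfd's GridArray / GridVariable.
@dataclasses.dataclass(frozen=True)
class ToyGridArray:
    data: np.ndarray
    offset: tuple

    def __add__(self, other):
        # Pointwise addition is well defined as long as both arrays
        # live at the same location (offset) on the grid.
        if self.offset != other.offset:
            raise ValueError("offsets differ")
        return ToyGridArray(self.data + other.data, self.offset)


@dataclasses.dataclass(frozen=True)
class ToyGridVariable:
    array: ToyGridArray
    bc: str  # hypothetical label, e.g. "periodic"
    # Intentionally no __add__: the BC of a result is not determined by the
    # operands' BCs in general, so `u + u` raises TypeError.


def add_with_bc(u, v, bc):
    """Hypothetical helper: add the underlying arrays, then attach a BC
    that the caller asserts is valid for the sum."""
    return ToyGridVariable(u.array + v.array, bc)


a = ToyGridArray(np.array([2.0, 2.0, 3.0, 4.0]), (0.5,))
u = ToyGridVariable(a, "periodic")

try:
    u + u
except TypeError as exc:
    error = exc  # unsupported operand type(s) for +

w = add_with_bc(u, u, "periodic")
print(w.array.data)  # [4. 4. 6. 8.]
```

Assuming the real GridVariable stores its fields as array and bc, the analogous explicit pattern would presumably be gd.GridVariable(u.array + v.array, bc); this has not been checked against the library here.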

> Also, am I correct in thinking that the way to use a JaxNumPy function on a GridArray is to call it via NumPy?

This is correct; I agree it's strange. The reason is simply that NumPy supports overriding its functions for new types (via the __array_ufunc__ and __array_function__ protocols), but JAX doesn't.
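A minimal sketch of that override mechanism, assuming a hypothetical ToyGridArray wrapper (not jax_cfd's actual class, whose internals we haven't checked here): NumPy calls __array_ufunc__ for np.abs, np.sin and so on, and NDArrayOperatorsMixin routes +, * etc. through the same hook, so the wrapper can forward each call to the same-named jax.numpy function. Calling jnp.abs directly has no equivalent hook, which matches the behaviour reported in the issue. The sketch falls back to NumPy as the backend when JAX isn't installed.

```python
import numpy as np

try:
    import jax.numpy as backend  # forward ufuncs to jax.numpy when available
except ImportError:
    backend = np  # fallback so the sketch still runs without JAX


class ToyGridArray(np.lib.mixins.NDArrayOperatorsMixin):
    """Hypothetical minimal wrapper: raw data plus a grid offset."""

    def __init__(self, data, offset):
        self.data = data
        self.offset = offset

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only plain calls like np.abs(x) or x + y are handled here;
        # returning NotImplemented makes NumPy raise TypeError otherwise.
        if method != "__call__" or kwargs:
            return NotImplemented
        offsets = {x.offset for x in inputs if isinstance(x, ToyGridArray)}
        if len(offsets) != 1:
            raise ValueError("inputs must share a single offset")
        raw = [x.data if isinstance(x, ToyGridArray) else x for x in inputs]
        # Look up the same-named backend function, e.g. "absolute" or "add".
        func = getattr(backend, ufunc.__name__)
        return ToyGridArray(func(*raw), offsets.pop())


x = ToyGridArray(backend.asarray([-2.0, 2.0, -3.0, 4.0]), (0.5,))
print(np.abs(x).data)  # np.abs -> __array_ufunc__ -> backend.absolute
print((x + x).data)    # x + x -> np.add -> __array_ufunc__ -> backend.add
```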

gemmaellen commented 2 years ago

Oh, I see! Yes, addition and multiplication would also work for matching homogeneous Dirichlet/Neumann BCs, but that's a special case, and it wouldn't extend to arbitrary elementwise functions (for example, the cosine of a field that vanishes on the boundary equals 1 there, not 0). I assume this is also the reason why the "shift" method on a GridVariable returns a GridArray. Thanks for the explanation!
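A short check of the reasoning above, assuming two fields u and v that share the same boundary condition on the boundary ∂Ω, and a differentiable elementwise function φ (notation introduced here for illustration):

```latex
% Homogeneous Dirichlet: u = v = 0 on \partial\Omega
(u+v)\big|_{\partial\Omega} = 0, \qquad (uv)\big|_{\partial\Omega} = 0, \qquad
\varphi(u)\big|_{\partial\Omega} = \varphi(0)
\;\Rightarrow\; \sin u\big|_{\partial\Omega} = 0, \quad \cos u\big|_{\partial\Omega} = 1

% Homogeneous Neumann: \partial_n u = \partial_n v = 0 on \partial\Omega
\partial_n (u+v) = 0, \qquad
\partial_n (uv) = u\,\partial_n v + v\,\partial_n u = 0, \qquad
\partial_n \varphi(u) = \varphi'(u)\,\partial_n u = 0

% Inhomogeneous Dirichlet: u = v = g on \partial\Omega, with g \neq 0
(u+v)\big|_{\partial\Omega} = 2g \neq g
```

So whether a result inherits its operands' BC depends on both the operation and the BC values: even plain addition breaks an inhomogeneous Dirichlet condition, which is consistent with leaving GridVariable arithmetic to the caller.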