ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
GNU Lesser General Public License v3.0
117 stars 7 forks

Make a wrapper to hide jaxdf computations #125

Open astanziola opened 1 year ago

astanziola commented 1 year ago

One feature that emerged immediately from the chat with @jejjohnson is the ability to use fields internally while hiding them from the user, so that user code does not have to create or handle them explicitly.

A common way to achieve this is the following pattern:

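from jax.typing import ArrayLike
from jaxdf.discretization import FourierSeries
from jaxdf.geometry import Domain

def my_awesome_func(u: ArrayLike):
  # Declare fields: wrap the raw array in a discretization
  N = u.shape
  dx = (0.1,) * len(N)
  u_field = FourierSeries(u, Domain(N, dx))

  # Perform the desired operation using jaxdf
  # (some_operator is a placeholder for any jaxdf operator)
  v_field = some_operator(u_field)

  # Return a simple jax array
  return v_field.on_grid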

To simplify the syntax and keep the implementation clean, this pattern can be encapsulated in a decorator:

@use_discretization(FourierSeries, dx=0.1)
def my_awesome_func(u_field):
  # The decorator has already wrapped the input array in a field
  return some_operator(u_field)

Here, the use_discretization decorator packs the input array into a field and unpacks the resulting field back into an array:

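import functools

from jaxdf.geometry import Domain

def use_discretization(discr_class, dx):
  def _decorator(func):

    @functools.wraps(func)
    def wrapper(u):
      # Declare fields using the discretization class and the grid
      # spacing passed to the decorator (dx is shared by all axes)
      N = u.shape
      spacing = (dx,) * len(N)
      u_field = discr_class(u, Domain(N, spacing))

      # Perform the desired operation using jaxdf
      v_field = func(u_field)

      # Return a simple jax array
      return v_field.on_grid

    return wrapper
  return _decorator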
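To check the packing and unpacking logic in isolation, here is a minimal, self-contained sketch that runs with NumPy alone. It assumes a hypothetical GridField class and a simplified Domain standing in for jaxdf's FourierSeries and Domain, and a hypothetical toy operator central_difference_x; none of these are the jaxdf API. It also lets dx be either a scalar shared by all axes or one value per axis.

```python
import functools

import numpy as np


class Domain:
    """Hypothetical stand-in for jaxdf's Domain: grid shape N and spacing dx."""

    def __init__(self, N, dx):
        self.N = tuple(N)
        self.dx = tuple(dx)


class GridField:
    """Hypothetical stand-in for a jaxdf discretization such as FourierSeries."""

    def __init__(self, params, domain):
        self.params = np.asarray(params)
        self.domain = domain

    @property
    def on_grid(self):
        # Values of the field sampled on the grid
        return self.params


def use_discretization(discr_class, dx):
    """Turn a field -> field function into an array -> array function."""

    def _decorator(func):
        @functools.wraps(func)
        def wrapper(u):
            u = np.asarray(u)
            N = u.shape
            # Accept a scalar spacing (shared by all axes) or one value per axis
            spacing = tuple(dx) if np.ndim(dx) > 0 else (float(dx),) * len(N)
            u_field = discr_class(u, Domain(N, spacing))  # pack
            v_field = func(u_field)                       # operate on fields
            return v_field.on_grid                        # unpack
        return wrapper

    return _decorator


def central_difference_x(field):
    """Hypothetical toy operator: periodic central difference along axis 0."""
    u = field.on_grid
    h = field.domain.dx[0]
    du = (np.roll(u, -1, axis=0) - np.roll(u, 1, axis=0)) / (2 * h)
    return GridField(du, field.domain)


@use_discretization(GridField, 0.1)
def derivative(u_field):
    return central_difference_x(u_field)
```

For example, with x = np.arange(8) * 0.1, calling derivative(3 * x) returns a plain NumPy array whose interior entries are 3.0; the two endpoints differ because the toy operator wraps around periodically. In a jaxdf version, GridField and Domain would be replaced by the real jaxdf classes, and the decorated function would receive a FourierSeries field.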

Potential issues and things to work out