ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations

API for Forwards, Backwards, Central Finite Difference #127

Open jejjohnson opened 1 year ago

jejjohnson commented 1 year ago

I would like to be able to control the finite difference scheme used, i.e. forward, backward, or central. Depending on the PDE, we normally use a scheme-specific choice, e.g. advection -> backward, diffusion -> central.
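For reference, a quick jax.numpy illustration of the three schemes for a first derivative on a uniform grid (plain array code, not jaxdf):

import jax.numpy as jnp

dx = 0.01
x = jnp.arange(0.0, 1.0, dx)
u = jnp.sin(2 * jnp.pi * x)

# forward:  (u[i+1] - u[i]) / dx
# backward: (u[i] - u[i-1]) / dx
# central:  (u[i+1] - u[i-1]) / (2 * dx)
du_forward = (u[1:] - u[:-1]) / dx        # values attributed to i = 0, ..., N-2
du_backward = (u[1:] - u[:-1]) / dx       # same array, but attributed to i = 1, ..., N-1
du_central = (u[2:] - u[:-2]) / (2 * dx)  # values attributed to i = 1, ..., N-2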


Working Demo

I have a working Colab notebook that gives a feeling for what I mean. See it here.


Proposed Solution

I don't have a full solution, but I think it is important to specify this somewhere in the param PyTree (just like the accuracy, order, step size, etc.).

u = DiscretizationScheme(u_init, domain)

class Params:
    method: str = static_field()              # "forward" | "backward" | "central"
    stagger: Iterable[int] = static_field()
    accuracy: int = static_field()

params = Params(method="central", stagger=[0], accuracy=2)

u_grad = gradient(u=u, params=params)

Another possible solution: one could use the FiniteDiffX package as a backend for generating the coefficients and kernel when they aren't specified explicitly. I recently contributed the ability to specify the FD scheme there.

Last solution: just create a custom operator that does all of the above. There is an example in the "custom equation of motion" section which does exactly what I want.

jejjohnson commented 1 year ago

I found the winning formula for the simple 1st-order backward FD scheme:

  u = FiniteDifferences.from_grid(u, domain)   # wrap the grid values as a FiniteDifferences field
  u.accuracy = 2                               # stencil accuracy
  u_rhs = -c * gradient(u, stagger=[1])        # staggered stencil -> backward difference (see below)

which generates the coefficients:

from jaxdf.conv import fd_coefficients_fornberg

grid_points = [1, 0]   # stencil node offsets
x0 = 0.0               # point at which the derivative is evaluated
order = 1              # first derivative
stencil, nodes = fd_coefficients_fornberg(order, grid_points, x0)

which produces:

# stencil, nodes
(array([-1.,  1.]), array([0, 1]))

which is equivalent to: -1 * u[0] + 1 * u[1]

$$ \frac{u_{i} - u_{i-1}}{dx} = \frac{-u_{i-1} + u_{i}}{dx} $$

where $i=1$.
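A quick sanity check with plain jax.numpy (independent of jaxdf) that this stencil really is the backward difference:

import jax.numpy as jnp

dx = 0.1
x = jnp.arange(0.0, 1.0, dx)
u = jnp.sin(x)

stencil = jnp.array([-1.0, 1.0])   # coefficients at nodes [0, 1]
# apply the stencil by hand: (-1 * u[i-1] + 1 * u[i]) / dx, for i = 1, ..., N-1
du_stencil = (stencil[0] * u[:-1] + stencil[1] * u[1:]) / dx
du_backward = (u[1:] - u[:-1]) / dx

assert jnp.allclose(du_stencil, du_backward)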


What I got wrong was the stagger. When I read stagger, my intuition was a staggered grid, not a staggered stencil. So I originally used -1, but then I saw that the FD method uses convolutions, so the kernel gets flipped.
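For anyone else hitting this, the flip is just the definition of convolution and can be seen with plain jax.numpy (no jaxdf involved):

import jax.numpy as jnp

u = jnp.array([0.0, 1.0, 4.0, 9.0])
kernel = jnp.array([-1.0, 1.0])

# correlation applies the kernel as written:  out[i] = -u[i] + u[i+1]
corr = jnp.correlate(u, kernel, mode="valid")   # [ 1.,  3.,  5.]
# convolution flips the kernel first:          out[i] =  u[i] - u[i+1]
conv = jnp.convolve(u, kernel, mode="valid")    # [-1., -3., -5.]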

What helped me understand was exposing all of the pieces involved in generating the FD kernel. For example:

# generate stencil nodes based on derivative, accuracy, method and stagger
nodes = get_fd_nodes(
    derivative=1,
    accuracy=1,
    method="central",
    stagger=0,
)

# get the FD coefficients for those nodes
coeffs = get_fd_coeffs(nodes, derivative=1)

# generate the FD convolution kernel (optional)
kernel = get_fd_kernel(coeffs, domain)
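A minimal sketch of what these helpers could look like on a uniform 1D grid, just to make the proposal concrete (none of this exists in jaxdf; the coefficient solve uses a Taylor/Vandermonde system instead of Fornberg's algorithm, and I've used a plain dx instead of a domain object):

import math
import numpy as np

def get_fd_nodes(derivative=1, accuracy=1, method="central", stagger=0):
    # node offsets (in grid units) for the requested scheme
    if method == "central":                      # assumes even `accuracy`
        half = (2 * ((derivative + 1) // 2) - 1 + accuracy) // 2
        nodes = np.arange(-half, half + 1)
    elif method == "forward":
        nodes = np.arange(0, derivative + accuracy)
    elif method == "backward":
        nodes = np.arange(-(derivative + accuracy) + 1, 1)
    else:
        raise ValueError(f"unknown method {method!r}")
    return nodes + stagger

def get_fd_coeffs(nodes, derivative=1):
    # solve the Taylor-matching system: sum_j c_j * nodes_j**k = k! * delta(k, derivative)
    nodes = np.asarray(nodes, dtype=float)
    A = np.vander(nodes, len(nodes), increasing=True).T
    b = np.zeros(len(nodes))
    b[derivative] = math.factorial(derivative)
    return np.linalg.solve(A, b)

def get_fd_kernel(coeffs, dx, derivative=1):
    # scale by the grid spacing and flip, so the kernel can be used in a convolution
    return np.asarray(coeffs)[::-1] / dx**derivative

# first-order backward difference: nodes [-1, 0], coefficients [-1, 1]
nodes = get_fd_nodes(derivative=1, accuracy=1, method="backward")
coeffs = get_fd_coeffs(nodes, derivative=1)      # -> array([-1.,  1.])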
astanziola commented 1 year ago

Hi and thanks for this! I will try to answer the main points, but let me know if there's something I'm missing.


I would like to be able to control the finite difference scheme used, i.e. forward, backward or central

I think you already figured this out, but that clearly means that there's a documentation page missing :) In general, one can do this

u = FiniteDifferences(jnp.zeros((128,)), domain)        # Declare field
u.accuracy = 4                                          # Choose the accuracy of the FD stencils
params = jops.gradient.default_params(u, stagger=[1])   # Choose grid staggering and get stencil

That returns

params: [array([-0.33333333, -0.5       ,  1.        , -0.16666667,  0.        ])]

In general, every operator has the .default_params method, which can be called with the same arguments as the operator and returns the parameters that the operator will use. So one is always free to modify them before calling the operator itself (and potentially construct new operators with them!). For example:

def gradient_with_modified_kernel(u, new_value):
  params = jops.gradient.default_params(u, stagger=[1]) 
  new_params = [params[0].at[4].set(new_value)]    # <-- modify using jax methods, and keep the same PyTree structure
  return gradient(u, params=new_params)    # <-- apply the operator with the modified parameters


I don't have a solution but somewhere in the param PyTree I think it is important to specify this (just like the accuracy, order, stepsize, etc).

The idea of the params input is to collect all differentiable parameters into a pytree, such that jax transformations (including jax.grad) work on params. That allows for performing things like optimization of the numerical parameters of a solver.
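For example (sketching with the same field and operator as above; `target` is a hypothetical reference field and `.on_grid` is assumed to give the raw array), one could take gradients with respect to the stencil coefficients themselves:

import jax
import jax.numpy as jnp

def loss(params, u, target):
    du = jops.gradient(u, params=params)       # apply the operator with candidate stencil parameters
    return jnp.sum((du.on_grid - target.on_grid) ** 2)

params = jops.gradient.default_params(u, stagger=[1])
grads = jax.grad(loss)(params, u, target)      # gradients w.r.t. the stencil coefficients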

For static arguments, like method='central', I'd be tempted to say that those should be optional arguments of the gradient operator definition for finite differences. How about defining the function signature like this, for example?

def gradient(
  u: FiniteDifferences,
  *,
  accuracy: int = 4,  # or, really, this should probably be `order`
  method: str = 'central',
  stagger=[0],
):
  ...
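With a signature like that, selecting the scheme would just be:

du_dx = gradient(u, accuracy=2, method='backward', stagger=[0])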


What I got wrong was the stagger. When I read stagger, my intuition was a staggered grid, not a staggered stencil. So I originally used -1, but then I saw that the FD method uses convolutions, so the kernel gets flipped.

That should indeed be a staggered grid, but I always get confused by how kernels are applied in convolutions, correlations, etc. :) It probably makes sense to write a quick test checking the returned stencils against the ones from FiniteDiffX?

Where do we place the accuracy term?

The current API for the gradient (or, generally, for every operation) follows the way Julia works (to the best of my knowledge), where operations are dispatched on the non-keyword arguments. That is, given an operator

def op(x, y, *, options=None, params=None): ...

then the multiple-dispatch system will choose the appropriate implementation of op based on the types of x and y. This is why, at the moment, such options need to be optional keyword arguments, so that simply calling op(x, y) generally works regardless of whether x is a FiniteDifferences or a FourierSeries field, for example.
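As a toy illustration of that mechanism (using the plum multiple-dispatch library and stand-in classes, not the actual jaxdf fields): the implementation is selected by the positional argument types, while keyword-only options never take part in dispatch.

from plum import dispatch

class FiniteDifferences: ...    # stand-ins for the jaxdf field types,
class FourierSeries: ...        # only to illustrate the dispatch mechanism

@dispatch
def op(x: FiniteDifferences, *, method: str = "central", params=None):
    return "finite-difference implementation"

@dispatch
def op(x: FourierSeries, *, params=None):
    return "Fourier implementation"

op(FiniteDifferences())                      # -> finite-difference implementation
op(FourierSeries())                          # -> Fourier implementation
op(FiniteDifferences(), method="backward")   # keyword arguments don't affect dispatch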

What is not clear to me is whether the accuracy term should be a property of the operator or of the FiniteDifferences field. On the one hand, it only affects the differential operators (including interpolation!), so it seems natural to attach it to the operator; on the other hand, it feels like a good feature to be able to say that a given field uses finite differences of a certain order, and then expect everything to work out. Any suggestions / ideas in this regard? I don't really use finite difference methods enough to know what's the best choice.

Actions (things to be done based on this discussion)