ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
GNU Lesser General Public License v3.0

Multiple dispatch and new Field(s) #62

Closed · astanziola closed 2 years ago

astanziola commented 2 years ago

This is a major PR that introduces many breaking changes to the codebase, arising from what I learned at the NeurIPS Differentiable Programming Workshop.

It is annoying that I have to do this already, and I apologise, but hopefully this introduces a cleaner interface for users and gets rid of that annoying construct-then-execute pattern that I've been using so far.

It also makes the codebase more readable, which I believe is a good thing for a project that aims to be a "hackable" and customizable library.


Changes

Fields are now PyTrees

This was inspired by learning more about equinox, which uses a very clever and simple idea for defining a PyTorch-like neural network library that has stateful modules while remaining amenable to all JAX function transformations.

The key component is now the Field class, which takes over the role of the previous Discretization class. All discretizations are instances of the Field class.

Unlike the previous Discretization, a Field object is now a PyTree that contains both trainable and non-trainable parameters, and can be freely passed to JAX functions as a non-static argument.

In particular, by defining a field as a JAX-compatible pytree we can exploit the whole tracing infrastructure of JAX to construct the correct computational graph, so the custom-made Tracer class (and its derived classes) is no longer needed.

In practice, this means that now we can (for example) define a FourierSeries field and directly manipulate it inside of a jax transformable function:

@jax.jit
def f(x):
  y = FourierSeries(x.params**2, x.domain)
  return y + 1

u = FourierSeries(theta, domain)  # u.params == theta
v = f(u)                          # v.params == theta**2 + 1
                                  # type(v) == FourierSeries
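
Because a field flattens to its parameters, JAX's autodiff composes with fields as well. A minimal sketch (assuming, as above, that params are the only pytree leaves of a FourierSeries):

import jax
from jax import numpy as jnp

def loss(u):
  v = f(u)                      # f is the jitted function defined above
  return jnp.sum(v.params**2)   # reduce the field to a scalar

grads = jax.grad(loss)(u)       # grads is itself a FourierSeries,
                                # with the same structure as u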

There's no need to wrap functions in the operator decorator anymore! (But that decorator still exists, as we shall see below.)

Furthermore, because fields are now class-based pytrees, we can define as many custom methods as we want and use them inside a jittable function!

New Fields can easily be defined using the jax.tree_util.register_pytree_node_class decorator from JAX:

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class MyFourierSeries(FourierSeries):
  # Inherits the pytree flatten/unflatten logic of FourierSeries
  def custom_method(self):
    ...
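
For reference, registering a class as a pytree only requires a flatten/unflatten pair. Here is a minimal sketch of a stand-alone field (the name MyField and the exact flatten rule are illustrative; jaxdf's own Field may organize its leaves differently):

from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class MyField:
  def __init__(self, params, domain):
    self.params = params   # leaves, traced by JAX transformations
    self.domain = domain   # static auxiliary data

  def tree_flatten(self):
    # Returns (leaves, auxiliary data)
    return (self.params,), (self.domain,)

  @classmethod
  def tree_unflatten(cls, aux, leaves):
    return cls(leaves[0], aux[0])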

Multiple-dispatch via operator decorator

This has been made possible by the wonderful plum library, and was inspired by Julia's multiple-dispatch system, which I learned about after the AbstractDifferentiation.jl talk.

Previously, the Operator class essentially implemented a multiple-dispatch system by hand, which was a pain to maintain and not very capable.

In practice, for a given operator it looked at the operator's name and called the corresponding method of (one of) the operands. This approach required defining a dummy Operator object for each possible operator. It also made it hard to implement binary operators whose numerical implementation depends on the types of both operands, without resorting to something like a switch statement.

Using plum, the Operator class is not needed anymore, and now the operator decorator can be used to define multiple-dispatch methods for any operator using type hints!

For example, the following code defines the tanh operator for Continuous and OnGrid fields:

from jaxdf import operator
from jaxdf.discretization import Continuous, OnGrid
from jax import numpy as jnp
from jax.tree_util import tree_map

@operator
def tanh(x: Continuous):
  get_x = x.aux['get_field']
  def _get_fun(params, coords):
    # Evaluate the field at the given coordinates, then apply tanh
    return jnp.tanh(get_x(params, coords))
  return Continuous(x.domain, _get_fun, x.params)

@operator
def tanh(x: OnGrid):
  # Apply tanh leaf-wise to the parameters of the grid-based field
  new_params = tree_map(jnp.tanh, x.params)
  return x.replace_params(new_params)

Of course, the user can override a specific implementation for a given operator by re-defining the function with the same type signature.
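
For instance, a user could shadow the OnGrid implementation above with their own version (a purely illustrative sketch, assuming x.params is a single array rather than a general pytree):

@operator
def tanh(x: OnGrid):
  # User override: same type signature, so it replaces the method above
  return x.replace_params(jnp.tanh(x.params))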

Some operators depend on parameters, such as the gradient operator in finite-differences schemes. To deal with parameters, instead of collecting all of them into a dictionary as done before, we adopt the following convention (not mandatory, but encouraged):

An operator returns at most two values: the first must be a field, and the optional second value contains the operator's default parameters.

The main reason for returning the default parameters is to let the user reuse them and skip their initialization when it is computationally demanding (I'm thinking of dense filters in Fourier space and the like).

As an example, this is the code for the gradient operator of FiniteDifferences fields:

@operator
def gradient(x: FiniteDifferences, params=None, accuracy=2, staggered='center'):
  if params is None:
    # Default parameters: the finite-difference stencil coefficients
    params = _fd_coefficients(1, accuracy, staggered)

  kernel = params
  new_params = _convolve_kernel(x, kernel)
  # Return the new field together with the parameters used to compute it
  return FiniteDifferences(new_params, x.domain), params

The parameters can be reused as follows:

x = FiniteDifferences(grid_values, domain)

# The default parameters are stored on the returned field
gradient_params = gradient(x, accuracy=2, staggered='center')._op_params

# Calling the gradient function again with the same parameters
nabla_x = gradient(x, accuracy=2, staggered='center', params=gradient_params)
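
The same pattern keeps a potentially expensive initialization outside of traced code: the pre-computed parameters can be passed into a jitted function as a regular argument. A minimal sketch (the function name apply_gradient is illustrative):

import jax

@jax.jit
def apply_gradient(u, op_params):
  # op_params is a traced argument, so no re-initialization
  # of the stencil happens inside the compiled function
  return gradient(u, accuracy=2, staggered='center', params=op_params)

nabla_x = apply_gradient(x, gradient_params)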

(An alternative probably worth exploring would be the reap and plant transformations defined in the Oryx library.)

Because operators are now defined simply using the operator decorator, the Operator and Primitive classes are no longer needed.

Other changes

codecov[bot] commented 2 years ago

Codecov Report

Merging #62 (70c2168) into main (70b7e26) will increase coverage by 13.75%. The diff coverage is 62.57%.

@@             Coverage Diff             @@
##             main      #62       +/-   ##
===========================================
+ Coverage   49.78%   63.54%   +13.75%     
===========================================
  Files          10       13        +3     
  Lines        1418      779      -639     
===========================================
- Hits          706      495      -211     
+ Misses        712      284      -428     
Impacted Files Coverage Δ
jaxdf/__init__.py 100.00% <ø> (ø)
jaxdf/version.py 100.00% <ø> (ø)
jaxdf/util.py 20.00% <20.00%> (ø)
jaxdf/ode.py 26.19% <28.57%> (+7.79%) ↑
jaxdf/operators/functions.py 54.09% <54.09%> (ø)
jaxdf/operators/differential.py 55.41% <55.41%> (ø)
jaxdf/operators/magic.py 59.62% <59.62%> (ø)
jaxdf/discretization.py 73.07% <72.86%> (+22.34%) ↑
jaxdf/operators/linear_algebra.py 75.00% <75.00%> (ø)
jaxdf/core.py 77.47% <77.27%> (+6.93%) :arrow_up:
... and 6 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 70b7e26...70c2168.