Closed astanziola closed 2 years ago
Merging #62 (70c2168) into main (70b7e26) will increase coverage by 13.75%. The diff coverage is 62.57%.
@@ Coverage Diff @@
## main #62 +/- ##
===========================================
+ Coverage 49.78% 63.54% +13.75%
===========================================
Files 10 13 +3
Lines 1418 779 -639
===========================================
- Hits 706 495 -211
+ Misses 712 284 -428
Impacted Files | Coverage Δ | |
---|---|---|
jaxdf/__init__.py | 100.00% <ø> (ø) | |
jaxdf/version.py | 100.00% <ø> (ø) | |
jaxdf/util.py | 20.00% <20.00%> (ø) | |
jaxdf/ode.py | 26.19% <28.57%> (+7.79%) | :arrow_up: |
jaxdf/operators/functions.py | 54.09% <54.09%> (ø) | |
jaxdf/operators/differential.py | 55.41% <55.41%> (ø) | |
jaxdf/operators/magic.py | 59.62% <59.62%> (ø) | |
jaxdf/discretization.py | 73.07% <72.86%> (+22.34%) | :arrow_up: |
jaxdf/operators/linear_algebra.py | 75.00% <75.00%> (ø) | |
jaxdf/core.py | 77.47% <77.27%> (+6.93%) | :arrow_up: |
... and 6 more | | |
Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
This is a major PR that introduces many breaking changes to the codebase, arising from what I've learned at the NeurIPS Differentiable Programming Workshop.
It is annoying that I have to do this already, and I apologise, but hopefully this introduces a cleaner interface for users and gets rid of that annoying `construct-then-execute` pattern that I've been using so far.
It also makes the codebase more readable, which I believe is a good thing for a project that aims to be a "hackable" and customizable library.
Changes
Fields are now `PyTrees`
The key component is now the `Field` class, which plays the role of the previous `Discretization` class. All discretizations are instances of the `Field` class.
Unlike the previous `Discretization`, a `Field` object is now a `PyTree` that contains both trainable and non-trainable parameters, and can be freely passed to JAX functions as a non-static argument.
In particular, by defining a field as a jax-compatible pytree we can exploit the whole tracing infrastructure of jax to construct the correct computational graph, so the custom-made `Tracer` class (and derived ones) is no longer needed.
In practice, this means that now we can (for example) define a `FourierSeries` field and directly manipulate it inside of a jax transformable function:
There's no need to wrap functions with the `operator` decorator anymore! (But that decorator still exists, as we shall see below)
Furthermore, because fields are now class-based pytrees, we can define as many custom methods as we want and use them inside a jittable function!
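As a minimal sketch of the idea (not the actual jaxdf code), the snippet below registers a hypothetical `ToyField` class as a pytree with `jax.tree_util.register_pytree_node_class` and then passes it straight into `jax.jit` and `jax.grad`. The class name, its `params` and `dx` attributes, and the `energy` method are assumptions made up for illustration. Here `dx` is stored as static auxiliary data, which is only one possible way to carry non-trainable values.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyField:
    """Hypothetical field (not a jaxdf class): trainable `params`, fixed spacing `dx`."""

    def __init__(self, params, dx):
        self.params = params
        self.dx = dx

    def tree_flatten(self):
        # Children are traced by JAX; the auxiliary data is treated as static.
        return (self.params,), self.dx

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(children[0], aux_data)

    def __add__(self, other):
        return ToyField(self.params + other.params, self.dx)

    def energy(self):
        # A custom method that can be called inside a jitted function.
        return jnp.sum(self.params ** 2) * self.dx


@jax.jit
def f(u, v):
    # Fields are ordinary (non-static) arguments here.
    return (u + v).energy()


u = ToyField(jnp.ones(4), 0.5)
v = ToyField(2.0 * jnp.ones(4), 0.5)
value = f(u, v)
grads = jax.grad(f)(u, v)  # gradient with respect to `u`, itself a ToyField
```

Because `ToyField` is a pytree, the gradient comes back with the same structure as the input, so `grads.params` holds the derivative with respect to `u.params`.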
New Fields can easily be defined using the `jax.tree_util.register_pytree_node_class` decorator (see here):
Multiple-dispatch via `operator` decorator
Previously, the `Operator` class was essentially implementing a multiple-dispatch system (which was a bit of a pain to implement, and was not advanced at all).
In practice, for a given operator it was looking at its `name` and calling the corresponding method of (one of) the operands. This approach required defining a dummy `Operator` object for each possible operator. Also, it was not possible to easily implement binary operators whose numerical implementation depends on the type of both operands, without resorting to something like a `switch` statement.
Using `plum`, the `Operator` class is not needed anymore, and now the `operator` decorator can be used to define multiple-dispatch methods for any operator using type hints!
For example, the following code defines the `tanh` operator for `Continuous` and `OnGrid` fields:
Of course, the user can override a specific implementation for a given operator by re-defining the function with the same type signature.
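To illustrate the dispatch-on-type-hints idea without depending on `plum` or on jaxdf itself, here is a toy, pure-Python stand-in for the `operator` decorator. It is a sketch under stated assumptions: the `Continuous` and `OnGrid` classes below are hypothetical placeholders that only borrow the names from the text, and the registry matches exact argument types only (no subclass resolution, unlike a real multiple-dispatch library).

```python
import inspect
import math
from dataclasses import dataclass


# Hypothetical stand-ins for field types; only the names come from the text.
@dataclass
class Continuous:
    value: float


@dataclass
class OnGrid:
    values: list


_REGISTRY = {}  # maps (function name, annotated argument types) -> implementation


def operator(fn):
    """Toy multiple-dispatch decorator keyed on the type hints of `fn`."""
    hints = tuple(p.annotation for p in inspect.signature(fn).parameters.values())
    _REGISTRY[(fn.__name__, hints)] = fn  # same signature overwrites the old entry

    def dispatcher(*args):
        # Pick the implementation registered for the runtime argument types.
        key = (fn.__name__, tuple(type(a) for a in args))
        return _REGISTRY[key](*args)

    dispatcher.__name__ = fn.__name__
    return dispatcher


@operator
def tanh(x: Continuous):
    return Continuous(math.tanh(x.value))


@operator
def tanh(x: OnGrid):
    return OnGrid([math.tanh(v) for v in x.values])
```

Calling `tanh(Continuous(0.3))` runs the first implementation and `tanh(OnGrid([0.1, 0.2]))` runs the second. Registering another function with the same name and type hints replaces the stored implementation, which mirrors the override behaviour described above.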
Some operators depend on parameters, such as the gradient operator in finite differences schemes. To deal with parameters, instead of collecting all of them into a dictionary as done before, we adopt the following (non mandatory, but encouraged) convention:
- `params` argument.
- `_op_params` attribute of the returned field. This attribute doesn't exist if no parameters are returned.
Operators can't return more than two values, and the first value must be a field.
The main reason for returning the default parameters is to allow the user to reuse them and avoid their initialization when this is computationally demanding (I'm thinking about dense filters in Fourier space and things like that).
As an example, this is the code for the `gradient` operator of `FiniteDifferences` fields:
The parameters can be reused as follows:
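The following is a minimal sketch of how this convention could look, not the actual jaxdf `gradient` implementation. It assumes that an operator returns a `(field, params)` pair and that the decorator attaches the second value to the field as `_op_params`. The `ToyField` class, the toy `operator` decorator, and the three-point central-difference stencil are all hypothetical choices made for illustration.

```python
import functools

import jax.numpy as jnp


class ToyField:
    """Hypothetical 1D field on a unit-spaced grid (not a jaxdf class)."""

    def __init__(self, values):
        self.values = values


def operator(fn):
    """Toy decorator: if `fn` returns (field, params), store params on the field."""

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            field, params = out
            field._op_params = params
            return field
        return out  # no parameters returned, so no `_op_params` attribute

    return wrapper


@operator
def gradient(x, params=None):
    # Build the default parameters only when the caller did not pass any.
    if params is None:
        params = jnp.array([-0.5, 0.0, 0.5])  # central-difference stencil
    # jnp.convolve flips its kernel, so flip the stencil to apply it as written.
    values = jnp.convolve(x.values, params[::-1], mode="same")
    return ToyField(values), params


u = ToyField(jnp.arange(5.0) ** 2)
du = gradient(u)                        # default parameters are initialised here
stencil = du._op_params                 # ...and can be read back from the result
du_again = gradient(u, params=stencil)  # reuse them, skipping the initialisation
```

In this toy case the stencil is trivial to build, but the same pattern applies when the default parameters are expensive to compute.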
(An alternative probably worth exploring would be the `reap` and `plant` transformations defined in the Oryx library.)
Because operators are now defined simply using the `operator` decorator, the `Operator` and `Primitive` classes are no longer needed.
Other changes
`rochacbruno/python-project-template`, like