DhairyaLGandhi opened 3 years ago
I haven't tried to run Zygote on Trixi, so the following comments are not based on experimental evidence but just my general knowledge of the tools.
The biggest problem right now is that Zygote doesn't support mutating operations. Our background in Trixi.jl is explicit discretizations of hyperbolic PDEs. Let's say N is the total number of discrete variables (number of scalar variables times degrees of freedom per scalar variable). Then, the evaluation of our semidiscretization (with otherwise fixed parameters) scales as O(N). However, we need to compute intermediate results to make the coding feasible. Allocating everything again and again also scales as O(N) (in contrast to some deep learning applications, where the computations are relatively more expensive than the allocations). Since our focus has traditionally been on explicit discretizations, we need mutating operations for performance.
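To illustrate the point (a minimal sketch in plain Julia, not Trixi.jl code; the function names are hypothetical): an allocating RHS creates fresh length-N arrays on every evaluation, while a mutating version reuses a preallocated output buffer.

```julia
# Hypothetical sketch, not actual Trixi.jl code: an O(N) right-hand side
# that is evaluated many times, e.g., once per Runge-Kutta stage.

# Allocating version: every call creates new length-N arrays.
rhs_allocating(u) = 2 .* u .+ sin.(u)

# Mutating version: writes into the preallocated buffer `du`,
# so repeated evaluations cause no further allocations.
function rhs!(du, u)
    @. du = 2 * u + sin(u)
    return nothing
end

u  = rand(1000)
du = similar(u)
rhs!(du, u)
# `du` now matches `rhs_allocating(u)`, without a fresh allocation per call.
```

Since the RHS is called once per stage of an explicit time integrator, the allocation cost is paid on every call in the allocating version, which is why it matters here.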
As far as I can see, one would need to choose one of the following options:
Thus, we haven't used reverse mode AD so far. See also https://github.com/trixi-framework/Trixi.jl/issues/462.
If my understanding of the tools of the trade isn't correct, I would be happy to learn more about it and discuss how we can enable reverse mode AD with Trixi.jl.
Having said this, I'm definitely interested in these tools. Having seen their presentation at JuliaCon 2021, Enzyme.jl looks interesting, but I haven't experimented with it yet.
Depending on the task, it is possible to write adjoints that hide the mutation in specific places, and the adjoints themselves aren't difficult to write. There's also `Zygote.Buffer`, which can handle mutation. That said, reverse mode would generally work much better with array code; scalar-like operations would usually be inefficient in reverse mode.
Would it be feasible to identify core functions that would require mutation and write adjoints over them? Typically, this adjoint itself would call `Zygote.pullback` (or `rrule_via_ad`). https://github.com/FluxML/Zygote.jl/pull/812 would be another option.
> Scalar-like operations would usually be inefficient in reverse mode.
That's another problem we would have. To make it easy to extend Trixi (different physics, different meshes, different solver types), our basic operations at the lowest level are scalar-like (or handle small static vectors).
> Would it be feasible to identify core functions that would require mutation and write adjoints over them?
That might be possible. If you want to invest some effort into this, I will be happy to assist you. A high-level overview of most of our ODE right-hand side (RHS) evaluations looks as follows; see, e.g., https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/solvers/dgsem_tree/dg_1d.jl#L79
```julia
function rhs!(du, u, parameters)
    du .= 0
    add_volume_terms!(du, u, parameters)  # boils down to scalar-like operations involving sometimes complicated functions
    add_surface_terms!(du, u, parameters) # same as above
    add_source_terms!(du, u, parameters)  # not problematic in most cases and can be ignored in a first step
    scale_by_jacobian!(du, parameters)    # easy to write an adjoint - basically a multiplication of `du` by a diagonal matrix
    # `du` is the updated value
end
```
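For the last of these steps, a sketch of what such an adjoint could look like (hypothetical signatures; the real `scale_by_jacobian!` takes `parameters`): if the step multiplies `du` elementwise by stored diagonal entries, the pullback is the same elementwise multiplication applied to the cotangent, since a real diagonal matrix is symmetric.

```julia
# Hypothetical sketch: scaling by a diagonal (inverse Jacobian) matrix.
# For y = D * x with real diagonal D, the pullback is
# xbar = D' * ybar = D * ybar, i.e., the same elementwise scaling.
function scale_by_jacobian!(du, d)  # d holds the diagonal entries
    @. du *= d
    return du
end

# Adjoint: elementwise multiplication by the same diagonal entries.
scale_by_jacobian_pullback(ybar, d) = ybar .* d

d  = [0.5, 2.0, 4.0]
du = [1.0, 1.0, 1.0]
scale_by_jacobian!(du, d)
xbar = scale_by_jacobian_pullback(ones(3), d)
```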
At the lowest level, the volume and surface terms reduce to non-mutating scalar-like operations such as https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/equations/compressible_euler_1d.jl#L259 or https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/equations/compressible_euler_1d.jl#L455
We could also start with some of the (more experimental) `DGMulti` solvers. I'm sure @jlchan has some old code lying around where he used some of the infrastructure powering the `DGMulti` solvers for AD: https://arxiv.org/abs/2006.07504. There, they basically derived some custom chain rules (making use of ForwardDiff.jl under the hood) for core routines used in Trixi.jl.
While implementing the `DGMulti` solvers in Trixi.jl, we switched to mutating operations to improve the performance (I don't remember the numbers exactly - do you, @jlchan?). However, it might also be feasible to dig out some old code and use it as a starting point. Then, we basically need matrix multiplications and pointwise nonlinear operations, which should be feasible to code from scratch for a simple example to get the ball rolling.
> While implementing the `DGMulti` solvers in Trixi.jl, we switched to mutating operations to improve the performance (I don't remember the numbers exactly - do you, @jlchan?).
Can you remind me which operations these are referring to?
> However, it might also be feasible to dig out some old code and use that as starting point. Then, we basically need matrix multiplications and pointwise nonlinear operations, which should be feasible to code from scratch for a simple example to get the ball rolling.
Agreed - this sounds like what I used to do before `DGMulti` was implemented in Trixi.
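A plain-Julia sketch of those two building blocks (hypothetical, not the actual `DGMulti` code): for a fixed operator matrix A (e.g., a differentiation or interpolation matrix), y = A*x has the pullback xbar = A'*ybar; for a pointwise nonlinearity y = g.(z), it is zbar = g'.(z) .* ybar; chaining the two gives the reverse pass of g.(A*x).

```julia
# Hypothetical sketch: reverse pass for y = g.(A * x) with a fixed matrix A
# and a pointwise function g with known derivative dg.
function forward_with_pullback(A, x, g, dg)
    z = A * x                          # matrix multiplication
    y = g.(z)                          # pointwise nonlinear operation
    back(ybar) = A' * (dg.(z) .* ybar) # chain rule, applied in reverse
    return y, back
end

A = [1.0 2.0; 3.0 4.0]
x = [0.1, -0.2]
y, back = forward_with_pullback(A, x, sin, cos)
xbar = back(ones(2))                   # gradient of sum(sin.(A * x))
```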
> Can you remind me which operations these are referring to?
Basically the difference between your first steps in https://github.com/trixi-framework/Trixi.jl/pull/484 (and https://github.com/trixi-framework/Trixi.jl/pull/557) and the form we have in Trixi.jl right now: no caches, matrix multiplications using the allocating form `*` instead of `mul!`, etc.
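For reference, a minimal illustration of the two forms (plain Julia, not Trixi.jl code):

```julia
using LinearAlgebra  # provides mul!

A = rand(4, 4)
B = rand(4, 8)

# Allocating form: each call creates a fresh 4x8 result matrix.
C_alloc = A * B

# Mutating form: writes into a preallocated buffer, no per-call allocation.
C = similar(C_alloc)
mul!(C, A, B)
```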
The allocating version I first tried was similar in speed to my Matlab codes, but I don't think the speedup from reducing allocations was huge. However, when `DGMulti` switched to `StructArrays`, these non-mutating calls were replaced with `StructArray.foreachfield`, which requires mutating calls by default.
@DhairyaLGandhi as @ranocha mentioned, we do have a pretty efficient method for computing custom adjoints of our `rhs!` evaluation (at least for the most interesting solver options, IMO). Is my understanding correct that if this is implemented, it could be used with Flux.jl for fast reverse mode AD?
Larger PDE solves may benefit from reverse mode AD, and it would also help with adding support for neural networks as surrogates or as part of training the PDEs themselves, so I was wondering what it would take to add Zygote support here. Zygote also comes with a forward mode, which may be usable as well.