trixi-framework / Trixi.jl

Trixi.jl: Adaptive high-order numerical simulations of conservation laws in Julia
https://trixi-framework.github.io/Trixi.jl
MIT License

Adding reverse mode AD support #827

Open DhairyaLGandhi opened 3 years ago

DhairyaLGandhi commented 3 years ago

Larger PDE solves may benefit from reverse-mode AD, which would also help with adding support for neural networks as surrogates or as part of training the PDEs themselves. I was therefore wondering what it would take to add Zygote support here. Zygote also comes with a forward mode, which may be useful as well.

ranocha commented 3 years ago

I haven't tried to run Zygote on Trixi, so the following comments are not based on experimental evidence but just my general knowledge of the tools.

The biggest problem right now is that Zygote doesn't support mutating operations. Our background in Trixi.jl is explicit discretizations of hyperbolic PDEs. Let N be the total number of discrete variables (number of scalar variables times degrees of freedom per scalar variable). Then, evaluating our semidiscretization (with otherwise fixed parameters) scales as O(N). However, we need to compute intermediate results to make the coding feasible, and allocating all of them again and again also scales as O(N) (in contrast to some deep learning applications, where the computations scale worse than the allocations). Since our focus has traditionally been on explicit discretizations, we need mutating operations for performance.

As far as I can see, one would need to choose one of the following options:

Thus, we haven't used reverse mode AD so far. See also https://github.com/trixi-framework/Trixi.jl/issues/462.

If my understanding of the tools of the trade isn't correct, I would be happy to learn more about it and discuss how we can enable reverse mode AD with Trixi.jl.

ranocha commented 3 years ago

Having said this, I'm definitely interested in these tools. After seeing the presentation at JuliaCon 2021, I think Enzyme.jl looks interesting, but I haven't experimented with it yet.

DhairyaLGandhi commented 3 years ago

Depending on the task, it is possible to write adjoints that hide the mutation in specific places, and the adjoints themselves aren't difficult to write. There's also Zygote.Buffer, which can handle mutation. That said, reverse mode generally works much better with array code. Scalar-like operations would usually be inefficient in reverse mode.

Would it be feasible to identify core functions that would require mutation and write adjoints over them? Typically, this adjoint itself would call Zygote.pullback (or rrule_via_ad). https://github.com/FluxML/Zygote.jl/pull/812 would be another option.
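To make the idea of "an adjoint over a mutating core function" concrete, here is a minimal sketch in plain Julia. Everything in it is an assumption made for illustration: the kernel `central_flux_difference!`, the Burgers-type `flux`, and the wrapper name are hypothetical and not part of Trixi.jl. The wrapper has the shape of a reverse-mode rule (value plus pullback); registering it with Zygote would go through a ChainRules-style `rrule`, which is deliberately not shown here.

```julia
# Hypothetical pointwise flux and its derivative (chosen only for illustration).
flux(u) = u^2 / 2
dflux(u) = u

# Hypothetical mutating kernel: periodic central difference of the flux.
function central_flux_difference!(du, u)
    n = length(u)
    for i in 1:n
        right = i == n ? 1 : i + 1   # periodic right neighbor
        left = i == 1 ? n : i - 1    # periodic left neighbor
        du[i] = flux(u[right]) - flux(u[left])
    end
    return du
end

# Non-mutating wrapper returning the value and a pullback. The pullback maps
# an output cotangent `du_bar` to the input cotangent `u_bar = J' * du_bar`.
# The mutation stays hidden inside the wrapper.
function central_flux_difference_with_pullback(u)
    du = central_flux_difference!(similar(u), u)
    function pullback(du_bar)
        n = length(u)
        u_bar = zero(u)
        for i in 1:n
            right = i == n ? 1 : i + 1
            left = i == 1 ? n : i - 1
            # du[i] depends on u[right] with weight +f'(u[right])
            # and on u[left] with weight -f'(u[left]).
            u_bar[right] += dflux(u[right]) * du_bar[i]
            u_bar[left] -= dflux(u[left]) * du_bar[i]
        end
        return u_bar
    end
    return du, pullback
end
```

The pullback costs about as much as the forward kernel and reuses the same loop structure, which is the property that would make such hand-written rules attractive for O(N) kernels.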

ranocha commented 3 years ago

> Scalar-like operations would usually be inefficient in reverse mode.

That's another problem we would have. To make it easy to extend Trixi (different physics, different meshes, different solver types), our basic operations at the lowest level are scalar-like (or handling small static vectors).

> Would it be feasible to identify core functions that would require mutation and write adjoints over them?

That might be possible. If you want to invest some effort into this, I will be happy to assist you. A high-level overview of most of our ODE right-hand side (RHS) evaluations looks as follows (see, e.g., https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/solvers/dgsem_tree/dg_1d.jl#L79):

function rhs!(du, u, parameters)
  du .= 0
  add_volume_terms!(du, u, parameters) # boils down to scalar-like operations involving sometimes complicated functions
  add_surface_terms!(du, u, parameters) # same as above
  add_source_terms!(du, u, parameters) # not problematic in most cases and can be ignored in a first step
  scale_by_jacobian!(du, parameters) # easy to write an adjoint - basically a multiplication of `du` by a diagonal matrix
  # `du` is the updated value
end
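The remark about the last step can be illustrated with a short sketch. It assumes the scaling is stored as a vector of diagonal entries; the name `diagonal_scaling` and the two-argument signature are hypothetical and differ from the `scale_by_jacobian!(du, parameters)` call in the overview above. Because a diagonal matrix is symmetric, the pullback of the scaling is the same scaling applied to the incoming cotangent.

```julia
# Hypothetical stand-in for the last RHS step: du <- D * du, where the
# diagonal matrix D is stored as the vector `diagonal_scaling`.
function scale_by_jacobian!(du, diagonal_scaling)
    du .*= diagonal_scaling
    return du
end

# Pullback of the linear map above: D' * du_bar, and D' == D for diagonal D.
function scale_by_jacobian_pullback(du_bar, diagonal_scaling)
    return diagonal_scaling .* du_bar
end
```

A quick sanity check for such a rule is the adjoint identity dot(D * x, y) == dot(x, D' * y) for arbitrary vectors x and y.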

At the lowest level, the volume and surface terms reduce to non-mutating scalar-like operations such as https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/equations/compressible_euler_1d.jl#L259 or https://github.com/trixi-framework/Trixi.jl/blob/275567fe3fa87d825622c974b174d188cea7c126/src/equations/compressible_euler_1d.jl#L455
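For readers who do not want to follow the links, here is a sketch of what such a non-mutating, scalar-like operation can look like. It is not copied from the linked file: the function name, the tuple representation of the state, and passing `gamma` as an argument are assumptions made for illustration. It evaluates the standard 1D compressible Euler flux for an ideal gas from the conserved variables.

```julia
# Pointwise flux of the 1D compressible Euler equations (ideal gas).
# `u` holds the conserved variables (density, momentum, total energy);
# `gamma` is the ratio of specific heats.
function euler_flux_1d(u, gamma)
    rho, rho_v1, rho_e = u
    v1 = rho_v1 / rho                                # velocity
    p = (gamma - 1) * (rho_e - 0.5 * rho_v1 * v1)    # pressure
    return (rho_v1, rho_v1 * v1 + p, (rho_e + p) * v1)
end
```

The function takes a small immutable state and returns a small immutable result, so it involves no mutation, but it is called once per node or interface, which is the scalar-like pattern discussed above.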

ranocha commented 3 years ago

We could also start with some of the (more experimental) DGMulti solvers. I'm sure @jlchan has some old code lying around where he used some of the infrastructure powering DGMulti solvers for AD: https://arxiv.org/abs/2006.07504. There, they basically derived some custom chain rules (making use of ForwardDiff.jl under the hood) for core routines used in Trixi.jl.

ranocha commented 3 years ago

While implementing the DGMulti solvers in Trixi.jl, we switched to mutating operations to improve the performance (I don't remember the numbers exactly - do you, @jlchan?). However, it might also be feasible to dig out some old code and use that as starting point. Then, we basically need matrix multiplications and pointwise nonlinear operations, which should be feasible to code from scratch for a simple example to get the ball rolling.
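As a rough illustration of "matrix multiplications and pointwise nonlinear operations", the following sketch codes a tiny allocating right-hand side from scratch together with a hand-written reverse-mode rule. The setup is an assumption for illustration only: a periodic central-difference matrix stands in for the DG operators, the Burgers flux stands in for the physics, and none of the names come from Trixi.jl or the DGMulti infrastructure.

```julia
using LinearAlgebra

# Hypothetical differentiation matrix: periodic central differences on
# n >= 3 equally spaced nodes with spacing dx.
function periodic_central_difference_matrix(n, dx)
    D = zeros(n, n)
    for i in 1:n
        D[i, mod1(i + 1, n)] = 1 / (2dx)
        D[i, mod1(i - 1, n)] = -1 / (2dx)
    end
    return D
end

flux(u) = u^2 / 2   # pointwise nonlinearity (Burgers flux, for illustration)
dflux(u) = u        # its derivative

# Allocating, non-mutating right-hand side: du = -D * f(u)
rhs(u, D) = -(D * flux.(u))

# Value plus pullback: for an output cotangent `du_bar`, the input cotangent
# is u_bar = f'(u) .* (-D' * du_bar).
function rhs_with_pullback(u, D)
    du = rhs(u, D)
    pullback(du_bar) = dflux.(u) .* (-(D' * du_bar))
    return du, pullback
end
```

The pullback needs only one transposed matrix multiplication and one pointwise product, so its cost matches that of the forward evaluation.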

jlchan commented 3 years ago

> While implementing the DGMulti solvers in Trixi.jl, we switched to mutating operations to improve the performance (I don't remember the numbers exactly - do you, @jlchan?).

Can you remind me which operations these are referring to?

> However, it might also be feasible to dig out some old code and use that as starting point. Then, we basically need matrix multiplications and pointwise nonlinear operations, which should be feasible to code from scratch for a simple example to get the ball rolling.

Agreed - this sounds like what I used to do before DGMulti was implemented in Trixi.

ranocha commented 3 years ago

> Can you remind me which operations these are referring to?

Basically the difference between your first steps in https://github.com/trixi-framework/Trixi.jl/pull/484 (and https://github.com/trixi-framework/Trixi.jl/pull/557) and the form we have in Trixi.jl right now: No caches, matrix multiplications using the allocating form * instead of mul! etc.
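The difference between the two forms can be sketched as follows. The function names and the small matrix are made up for illustration; `*` and `LinearAlgebra.mul!` are the standard library operations referred to above.

```julia
using LinearAlgebra

# Allocating form: every call creates a new result vector.
apply_allocating(A, x) = A * x

# Mutating form: the result is written into a preallocated vector `y`,
# which plays the role of a cache that is reused between calls.
function apply_inplace!(y, A, x)
    mul!(y, A, x)
    return y
end

A = [2.0 0.0; 1.0 3.0]
x = [1.0, 2.0]
y = similar(x)
apply_inplace!(y, A, x)
```

Both forms compute the same product. The allocating one is the kind of code that reverse-mode tools without mutation support can handle directly, while the in-place one avoids the repeated allocation.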

jlchan commented 3 years ago

The allocating version I first tried was about as fast as my Matlab codes, and I don't think the speedup from reducing allocations was huge. However, when DGMulti switched to StructArrays, these non-mutating calls were replaced with StructArrays.foreachfield, which requires mutating calls by default.

@DhairyaLGandhi as @ranocha mentioned, we do have a pretty efficient method for computing custom adjoints of our rhs! evaluation (at least for the most interesting solver options IMO). Is my understanding correct that if this is implemented, it could be used with Flux.jl for fast reverse mode AD?