JuliaDiff / Capstan.jl

A Cassette-based automatic differentiation package for the Julia language

Complex Differentiation #1

Open jrevels opened 6 years ago

jrevels commented 6 years ago

Complex differentiation is one of those features that I had always planned on officially supporting in previous packages, but I never got around to grinding out the details/API consequences.

Since I see Capstan as my "fresh start" to AD now that Cassette is on the table, maybe it's time to dive into complex AD for real (pun intended).

Ideally, Capstan could eventually provide an API that supports differentiating:

Relevant resources include https://arxiv.org/abs/0906.4835 (thanks for introducing me to this material, @ssfrr).

ChrisRackauckas commented 6 years ago

This is the #1 issue. 👍

antoine-levitt commented 6 years ago

Yes please!

I think in general complex numbers should be seen simply as a struct of two reals, and the backend/API should treat them as such (and therefore the Jacobian of a C^N to C^M function is a 2M x 2N real matrix). Particular cases where a simpler API would be useful:

jrevels commented 6 years ago

@ssfrr and I met yesterday to discuss this. Here are some notes from our discussion:

The most general form of a complex derivative of a ℂ → ℂ function can be written as a 2x2 real-valued Jacobian. From this notion, we determined that "correct" complex AD follows directly from "correct" real AD, assuming that all target programs still boil down to real-valued primitives. Properties such as holomorphy can lend extra structure to the aforementioned Jacobian, but it is not necessary to exploit that extra structure from a pure correctness standpoint.
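To make that structure concrete, write f(x + iy) = u(x, y) + i v(x, y). The symbols u, v, a, and b are notation introduced here for illustration and are not taken from the discussion. The 2x2 real Jacobian, and the special form it takes when f is holomorphic, can then be written as:

$$
J_f = \begin{pmatrix} \dfrac{\partial u}{\partial x} & \dfrac{\partial u}{\partial y} \\[2ex] \dfrac{\partial v}{\partial x} & \dfrac{\partial v}{\partial y} \end{pmatrix},
\qquad
J_f^{\mathrm{hol}} = \begin{pmatrix} a & -b \\ b & a \end{pmatrix}
\quad \text{with} \quad f'(z) = a + ib .
$$

The second form is the Cauchy-Riemann equations (∂u/∂x = ∂v/∂y and ∂u/∂y = −∂v/∂x) restated as a matrix: a holomorphic function is described by two real numbers rather than four, which is the "extra structure" mentioned above.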

Thus, it would be sufficient for Capstan to provide an API that accepts complex numbers, but only does real-valued AD under the hood (i.e. the Complex{Dual}-style approach).
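A minimal sketch of the Complex{Dual}-style idea, written in Python rather than Julia purely for illustration. The names `Dual`, `ComplexDual`, and `real_jacobian` are hypothetical and are not Capstan's API; only `+`, `-`, `*`, negation, and `conj` are supported, which is enough to show that seeding the real and imaginary parts separately recovers the 2x2 real Jacobian.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dual:
    """Forward-mode dual number: a value plus one perturbation component."""
    val: float
    eps: float = 0.0

    def __add__(self, other):
        return Dual(self.val + other.val, self.eps + other.eps)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.eps - other.eps)

    def __mul__(self, other):
        # Product rule on the perturbation component.
        return Dual(self.val * other.val,
                    self.val * other.eps + self.eps * other.val)

    def __neg__(self):
        return Dual(-self.val, -self.eps)


@dataclass(frozen=True)
class ComplexDual:
    """A complex number whose real and imaginary parts are both Duals."""
    re: Dual
    im: Dual

    def __add__(self, other):
        return ComplexDual(self.re + other.re, self.im + other.im)

    def __mul__(self, other):
        # Ordinary complex multiplication, carried out on real Duals.
        return ComplexDual(self.re * other.re - self.im * other.im,
                           self.re * other.im + self.im * other.re)

    def conj(self):
        return ComplexDual(self.re, -self.im)


def real_jacobian(f, z):
    """Return [[du/dx, du/dy], [dv/dx, dv/dy]] of f at z = x + iy."""
    x, y = z.real, z.imag
    fx = f(ComplexDual(Dual(x, 1.0), Dual(y, 0.0)))  # perturb the real part
    fy = f(ComplexDual(Dual(x, 0.0), Dual(y, 1.0)))  # perturb the imaginary part
    return [[fx.re.eps, fy.re.eps],
            [fx.im.eps, fy.im.eps]]


# Holomorphic example: f(z) = z^2 yields the [[a, -b], [b, a]] pattern.
print(real_jacobian(lambda z: z * z, 1 + 2j))
# Non-holomorphic example: f(z) = z * conj(z) = |z|^2 has no such pattern.
print(real_jacobian(lambda z: z * z.conj(), 1 + 2j))
```

Two passes are needed here because each `Dual` carries a single perturbation; a real implementation would presumably carry both seeds at once.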

However, things get more complicated once you start considering array primitives in reverse-mode AD. Here, you might want to define complex array primitives for all the same reasons that it's beneficial to define real-valued array primitives (e.g. to avoid unrolling long loops, to provide a hand-optimized implementation, or because the kernel that executes the primitive isn't Julia code). The question then becomes: What must be specified to define a complex array primitive for reverse-mode AD?

Naively, one might think they need to specify the aforementioned Jacobian for their primitive. While boiling down to real-valued differentiation like this is convenient to think about, one's ability to actually write performant code for such calculations will depend heavily on the memory layout of the input arrays. It's quite likely the arrays will be in array-of-struct format (i.e. Array{Complex}), at which point separate interactions/propagations with bulk real/imaginary components can become a huge pain.

To avoid this problem, we can design our API for defining complex-valued primitives to use Wirtinger derivatives instead of derivatives w.r.t. the real/imaginary components. Wirtinger derivatives also make certain properties easy to exploit for performance gains. For example, holomorphic primitives satisfy df/d(conj(z)) = 0, and real-valued primitives with complex input satisfy df/d(conj(z)) = conj(df/dz). In both cases one Wirtinger derivative determines the other, so these properties essentially allow us to leverage Wirtinger derivatives without paying any extra memory cost compared to formulating the AD on real/imaginary components separately.
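For reference, with z = x + iy the Wirtinger derivatives are defined in terms of the real partial derivatives as follows. These are the standard definitions; the notation is chosen here and is not taken from the discussion.

$$
\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
\qquad
\frac{\partial f}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right).
$$

Both properties follow directly. If f = u + iv is holomorphic, the Cauchy-Riemann equations give

$$
\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}
= \left(\frac{\partial u}{\partial x} - \frac{\partial v}{\partial y}\right)
+ i\left(\frac{\partial v}{\partial x} + \frac{\partial u}{\partial y}\right) = 0
\quad\Longrightarrow\quad \frac{\partial f}{\partial \bar z} = 0 .
$$

If f is real-valued, then ∂f/∂x and ∂f/∂y are real, so conjugating the first definition flips only the sign of i:

$$
\overline{\frac{\partial f}{\partial z}}
= \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)
= \frac{\partial f}{\partial \bar z} .
$$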

Even with this choice, there are still some open design questions - mainly, how to handle dispatch for e.g. array primitives that have separate complex and real versions. Dispatching on element type can be a huge pain in general if it's mixed with container type dispatch, so my best idea at this point is to require that input arrays have eltype defined, and then dispatch on the result of eltype separately from the container. We'll see how it goes, though.

ssfrr commented 6 years ago

I put together a little demo implementation of Complex{Dual} perturbation seeding with test cases and a little f: C->R gradient-descent demo:

https://gist.github.com/ssfrr/6dcb548c06e18e54c35fc89874fad553
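Independently of the gist (whose code is not reproduced here), the gradient-descent idea for an f: C->R function can be sketched in Python. To stay self-contained, this sketch swaps the Complex{Dual} seeding for central finite differences. The names `numeric_gradient` and `descend`, the step size, the iteration count, and the target point `c` are all assumptions chosen for illustration.

```python
def numeric_gradient(f, z, h=1e-6):
    """Approximate df/dx + i*df/dy for a real-valued f of one complex argument.

    This quantity equals 2 * df/d(conj(z)) and points in the direction of
    steepest ascent in the complex plane. Central finite differences stand in
    for dual-number seeding here.
    """
    dfdx = (f(z + h) - f(z - h)) / (2 * h)
    dfdy = (f(z + 1j * h) - f(z - 1j * h)) / (2 * h)
    return complex(dfdx, dfdy)


def descend(f, z0, step=0.1, iters=200):
    """Plain gradient descent on a real-valued function of a complex variable."""
    z = complex(z0)
    for _ in range(iters):
        z -= step * numeric_gradient(f, z)
    return z


# Hypothetical objective: squared distance to a fixed target point c.
c = 3.0 - 2.0j
f = lambda z: abs(z - c) ** 2

z_min = descend(f, 0j)
print(abs(z_min - c))  # small: the iterate ends up close to c
```

Packing the two real partial derivatives into one complex number is what lets the update be written as a single complex subtraction, which mirrors the role df/d(conj(z)) plays for real-valued functions.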

ssfrr commented 5 years ago

Some more possibly relevant discussion and info: