FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

Complex Number Interfaces #142

Closed MikeInnes closed 5 years ago

MikeInnes commented 5 years ago

Moving this discussion here from #29. We currently have a consistent way of treating complex numbers that's useful in the case of real-valued output (for gradient descent) but not always aligned with other notions of the complex derivative. At best it's only a conjugate away from a useful derivative, and at worst it's partial information (only one column of a 2x2 Jacobian).

@ssfrr suggests separating the gradient function from a derivative function that uses the more traditional definition and can also check numerically that the derivative is well defined.

ssfrr commented 5 years ago

To expand a little, a D = derivative(f, z) function would return the derivative D such that f(z+dz) ≈ f(z)+D*dz (for small dz).

This generalizes nicely for scalar and vector-valued functions, both Real and Complex:

  1. For scalar R → R or holomorphic C → C functions it would return a scalar Real or Complex
  2. For R^n → R or holomorphic C^n → C it returns a row vector (this is the Hermitian transpose of the gradient).
  3. For R^n → R^m or holomorphic C^n → C^m it returns an m × n matrix.
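
As a rough illustration of the defining relation f(z+dz) ≈ f(z)+D*dz, here is a minimal sketch in plain Julia that plugs in hand-written derivatives rather than calling any AD package. The names f, g, Df, and Jg are hypothetical and chosen just for this example.

```julia
# Scalar holomorphic case (C → C): D is a complex scalar.
f(z) = z^2
Df(z) = 2z                        # hand-derived derivative of z^2

# Vector holomorphic case (C^2 → C^2): D is an m × n (here 2 × 2) matrix.
g(v) = [v[1] * v[2], v[1]^2]
Jg(v) = [v[2] v[1]; 2*v[1] 0]     # hand-derived Jacobian of g

# Error of the first-order approximation in the scalar case.
z, dz = 0.3 + 0.7im, 1e-4 * (1 - 2im)
scalar_err = abs(f(z + dz) - (f(z) + Df(z) * dz))

# Same check in the vector case, using a matrix-vector product.
v, dv = [0.3 + 0.7im, -1.2 + 0.1im], 1e-4 * [1 - 2im, 0.5 + 1im]
vector_err = maximum(abs.(g(v + dv) - (g(v) + Jg(v) * dv)))
```

Both errors are second order in the step size, which is what the relation above asks for.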

Nonholomorphic functions don't have a complex derivative, so derivative would throw an error for them. Checking for nonholomorphism is doable numerically but adds overhead:

Given a function f(z) where z=x+im*y, we can write it as f(x, y) = u(x, y) + im*v(x, y). If the function is holomorphic then it satisfies the Cauchy-Riemann equations:

∂u   ∂v
── = ──
∂x   ∂y

and

∂v     ∂u
── = - ──
∂x     ∂y

We can check this in derivative:

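using Zygote

# Numerically check the Cauchy-Riemann equations for f at the point z.
function isholo(f, z)
    _, pb = pullback(f, z)  # `pullback` was named `forward` when this was written
    du, = pb(1)    # under Zygote's convention: ∂u/∂x + im*∂u/∂y
    dv, = pb(im)   # under Zygote's convention: ∂v/∂x + im*∂v/∂y
    real(du) ≈ imag(dv) && real(dv) ≈ -imag(du)
end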

In the holomorphic case you only need to compute one of pb(1) or pb(im), so perhaps if the user supplied an argument holomorphic=true or something it could skip the check.
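
For comparison, the same Cauchy-Riemann condition can be checked without AD using central finite differences. This is only a sketch for sanity-checking the idea, not a proposal for how derivative should do it; cr_residual and the step size h are hypothetical names and values.

```julia
# Finite-difference residual of the Cauchy-Riemann equations at z.
# `cr_residual` and the default step `h` are assumptions made for this sketch.
function cr_residual(f, z; h=1e-6)
    # Central differences along the real and imaginary axes.
    fx = (f(z + h) - f(z - h)) / (2h)          # ≈ ∂u/∂x + im*∂v/∂x
    fy = (f(z + im*h) - f(z - im*h)) / (2h)    # ≈ ∂u/∂y + im*∂v/∂y
    # fx + im*fy = (∂u/∂x - ∂v/∂y) + im*(∂v/∂x + ∂u/∂y),
    # which is zero exactly when both Cauchy-Riemann equations hold.
    return abs(fx + im*fy)
end

z0 = 0.3 + 0.7im
res_holo = cr_residual(z -> z^2, z0)   # holomorphic: residual near zero
res_conj = cr_residual(conj, z0)       # not holomorphic: residual near 2
```

A real check would need a tolerance that accounts for the step size and the scale of f, which is part of the overhead mentioned above.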

It's worth noting that this falls more elegantly out of the Wirtinger stuff that @jrevels and I have been banging at for a bit (see this issue for context, and Jarrett has implemented a lot of it in ChainRules). In the Wirtinger framework the df/dz* term is zero for holomorphic functions so you can just represent it with a Zero type and propagate it through without needing the extra computation.
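
For reference, these are the standard Wirtinger derivatives, written with the same z = x + iy and f = u + iv as in the Cauchy-Riemann discussion (i is the imaginary unit, im in Julia):

```latex
\begin{align*}
\frac{\partial f}{\partial z}
  &= \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
&
\frac{\partial f}{\partial z^*}
  &= \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)
\\[1ex]
\frac{\partial f}{\partial z^*}
  &= \frac{1}{2}\left[\left(\frac{\partial u}{\partial x} - \frac{\partial v}{\partial y}\right)
   + i\left(\frac{\partial v}{\partial x} + \frac{\partial u}{\partial y}\right)\right]
\end{align*}
```

So the df/dz* term vanishes exactly when both Cauchy-Riemann equations hold.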

MikeInnes commented 5 years ago

> In the Wirtinger framework the df/dz* term is zero for holomorphic functions so you can just represent it with a Zero type and propagate it through without needing the extra computation.

Presumably this only applies when the function is not composed of any non-holomorphic functions (e.g. x -> real(x) + imag(x)*im). In that case you'd have to propagate a non-zero second Wirtinger derivative for the intermediate results only to find they cancel out at the end, which is more work than just differentiating real(f(x)).
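
To make the cancellation concrete, here is that example worked through with the standard Wirtinger derivatives (treating z and z* as independent variables). The labels g and h are just names for the two non-holomorphic pieces.

```latex
\begin{align*}
g(z) &= \operatorname{Re} z = \tfrac{1}{2}(z + z^*),
  & \frac{\partial g}{\partial z} &= \tfrac{1}{2},
  & \frac{\partial g}{\partial z^*} &= \tfrac{1}{2} \\
h(z) &= i \operatorname{Im} z = \tfrac{1}{2}(z - z^*),
  & \frac{\partial h}{\partial z} &= \tfrac{1}{2},
  & \frac{\partial h}{\partial z^*} &= -\tfrac{1}{2} \\
f(z) &= g(z) + h(z) = z,
  & \frac{\partial f}{\partial z} &= 1,
  & \frac{\partial f}{\partial z^*} &= \tfrac{1}{2} - \tfrac{1}{2} = 0
\end{align*}
```

Each piece carries a non-zero z* derivative, and only the sum is holomorphic.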

> how to handle dispatch for e.g. array primitives that have separate complex and real versions.

This comment is concerning, IIUC; having separate adjoints for functions of real and complex types seems pretty undesirable to me and not overall that elegant, compared to just using the same adjoints everywhere and calculating the Wirtinger derivative separately if it's required.

tkf commented 5 years ago

I saw that @MikeInnes commented that ChainRules is in the pipeline (https://github.com/FluxML/Zygote.jl/pull/235#issuecomment-503662943) and that @oxinabox is working on PR #291, which does that. But could complex number handling be a blocker for #291, given that ChainRules handles complex differentiation differently? Or is the ChainRules interface flexible enough that Zygote can opt out of it?

oxinabox commented 5 years ago

This question is on my radar and has been for a while. I don't know the answer yet, but there will be an answer.

MikeInnes commented 5 years ago

I'm closing this for now as I think we're happy with how Zygote does complex AD (of course ChainRules integration is an open question, but we can discuss that as part of #291).