Closed MikeInnes closed 5 years ago
To expand a little, a `D = derivative(f, z)` function would return the derivative `D` such that `f(z+dz) ≈ f(z)+D*dz` (for small `dz`).
This generalizes nicely for scalar and vector-valued functions, both `Real` and `Complex`:

- For `R → R` or holomorphic `C → C` functions it would return a scalar `Real` or `Complex`.
- For `R^n → R` or holomorphic `C^n → C` it returns a row vector (this is the Hermitian transpose of `gradient`).
- For `R^n → R^m` or holomorphic `C^n → C^m` it returns an `m × n` matrix.

Nonholomorphic functions don't have derivatives so they'd throw an error. Checking for nonholomorphism is doable numerically but adds overhead:
Given a function `f(z)` where `z=x+im*y`, we can write it as `f(x, y) = u(x, y) + im*v(x, y)`. If the function is holomorphic then it satisfies the Cauchy-Riemann equations:

    ∂u   ∂v
    ── = ──
    ∂x   ∂y

and

    ∂v     ∂u
    ── = - ──
    ∂x     ∂y

We can check this in `derivative`:
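
    function isholo(f, z)
        # Zygote's `forward` returns the value and a pullback closure
        _, pb = forward(f, z)
        # pull back the seeds 1 and im
        du, = pb(1)
        dv, = pb(im)
        # compare the two results against the Cauchy-Riemann equations
        real(du) ≈ imag(dv) && real(dv) ≈ -imag(du)
    end
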
In the holomorphic case you only need to compute one of `pb(1)` or `pb(im)`, so perhaps if the user supplied an argument `holomorphic=true` or something it could skip the check.
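
For anyone who wants to try the Cauchy-Riemann test without an AD package, here is a rough sketch that uses central finite differences instead of pullbacks. The name `isholo_fd`, the step `h`, and the tolerance `rtol` are all made up for illustration and are not part of Zygote. It relies on the complex form of the equations, `∂f/∂y == im * ∂f/∂x`.

```julia
# Hypothetical helper (not part of Zygote): numerically test the
# Cauchy-Riemann equations for f at z using central finite differences.
function isholo_fd(f, z::Complex; h=1e-6, rtol=1e-4)
    # Partial derivatives of f along the real and imaginary axes.
    df_dx = (f(z + h) - f(z - h)) / (2h)
    df_dy = (f(z + im*h) - f(z - im*h)) / (2h)
    # Cauchy-Riemann in complex form: ∂f/∂y == im * ∂f/∂x.
    return isapprox(df_dy, im * df_dx; rtol=rtol, atol=rtol)
end

isholo_fd(z -> z^2, 1.0 + 2.0im)   # z^2 is holomorphic, so the check should pass
isholo_fd(conj, 1.0 + 2.0im)       # conj is not holomorphic, so it should fail
```

This costs four extra evaluations of `f` per point and only tests the single point `z`, which is the same kind of overhead trade-off mentioned above.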
It's worth noting that this falls more elegantly out of the Wirtinger stuff that @jrevels and I have been banging at for a bit (see this issue for context, and Jarrett has implemented a lot of it in ChainRules). In the Wirtinger framework the `df/dz*` term is zero for holomorphic functions so you can just represent it with a `Zero` type and propagate it through without needing the extra computation.
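
For reference, these are the usual textbook definitions of the two Wirtinger derivatives, with `z = x + iy` and `f = u + iv`. The notation is the standard one and is not taken from ChainRules.

```latex
\frac{\partial f}{\partial z}
  = \frac{1}{2}\left(\frac{\partial f}{\partial x} - i\,\frac{\partial f}{\partial y}\right),
\qquad
\frac{\partial f}{\partial \bar z}
  = \frac{1}{2}\left(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\right)
  = \frac{1}{2}\left[\left(\frac{\partial u}{\partial x} - \frac{\partial v}{\partial y}\right)
  + i\left(\frac{\partial v}{\partial x} + \frac{\partial u}{\partial y}\right)\right]

df = \frac{\partial f}{\partial z}\,dz + \frac{\partial f}{\partial \bar z}\,d\bar z
```

Both brackets in the second expression vanish exactly when the Cauchy-Riemann equations hold, so `df/dz*` is zero for holomorphic `f` and the differential reduces to `f(z+dz) ≈ f(z) + (∂f/∂z)*dz`.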
> In the Wirtinger framework the `df/dz*` term is zero for holomorphic functions so you can just represent it with a Zero type and propagate it through without needing the extra computation.
Presumably this only applies when the function is not composed of any non-holomorphic functions (e.g. `x -> real(x) + imag(x)*im`). In that case you'd have to propagate a non-zero second Wirtinger derivative for the intermediate results only to find they cancel out at the end, which is more work than just differentiating `real(f(x))`.
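
To make the cancellation concrete, here is that example worked by hand, writing `real` and `imag` in terms of `z` and its conjugate (standard identities, nothing specific to Zygote or ChainRules):

```latex
\operatorname{Re} z = \frac{z + \bar z}{2}, \qquad
i\,\operatorname{Im} z = \frac{z - \bar z}{2}

\frac{\partial}{\partial \bar z}\operatorname{Re} z = \frac{1}{2}, \qquad
\frac{\partial}{\partial \bar z}\bigl(i\,\operatorname{Im} z\bigr) = -\frac{1}{2}, \qquad
\frac{\partial}{\partial \bar z}\bigl(\operatorname{Re} z + i\,\operatorname{Im} z\bigr)
  = \frac{1}{2} - \frac{1}{2} = 0
```

Each intermediate carries a non-zero conjugate derivative of magnitude 1/2, and they only cancel once the two branches are summed. The composite is the identity map, whose `∂/∂z` is 1.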
> how to handle dispatch for e.g. array primitives that have separate complex and real versions.
This comment is concerning, IIUC; having separate adjoints for functions of real and complex types seems pretty undesirable to me and not overall that elegant, compared to just using the same adjoints everywhere and calculating the Wirtinger derivative separately if it's required.
I saw @MikeInnes commented that ChainRules is in the pipeline https://github.com/FluxML/Zygote.jl/pull/235#issuecomment-503662943 and @oxinabox is working on a PR #291 that does that. But can complex number handling be a blocker for #291 given that ChainRules supports complex differentiation differently? Or is the ChainRules interface flexible enough that Zygote can opt out of it?
This question is on my radar and has been for a while. I don't know the answer yet, but there will be an answer.
I'm closing this for now as I think we're happy with how Zygote does complex AD (of course ChainRules integration is an open question, but we can discuss that as part of #291).
Moving this discussion here from #29. We currently have a consistent way of treating complex numbers that's useful in the case of real-valued output (for gradient descent) but not always aligned with other notions of the complex derivative. At best it's only a conjugate away from a useful derivative, and at worst it's partial information (only one column of a 2x2 Jacobian).
@ssfrr suggests separating the `gradient` function from a `derivative` function that uses the more traditional definition, and can also make numerical checks that the derivative is a valid operation.