JuliaDiff / ChainRulesCore.jl

AD-backend-agnostic system for defining custom forward- and reverse-mode rules. This is the lightweight core that lets you define rules for the functions in your packages without depending on any particular AD system.

RFC: Rules for real-to-complex and complex-to-real functions #176

Open sethaxen opened 4 years ago

sethaxen commented 4 years ago

Consider a function f: ℝᵐ → ℝⁿ that passes through complex intermediates, f: ℝᵐ → ℂʳ → ℂˢ → ℝⁿ, which we can write as f = f₃ ∘ f₂ ∘ f₁ with f₁: ℝᵐ → ℂʳ, f₂: ℂʳ → ℂˢ and f₃: ℂˢ → ℝⁿ. Typically f₁ produces a complex output by adding, subtracting, multiplying or dividing the real input by a complex number, or by calling promote, complex, Complex or cis. Typically f₃ produces a real output by calling a non-holomorphic function such as real, imag, abs, abs2, hypot or angle.
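As a minimal sketch, one concrete instance of this shape with m = r = s = n = 1 might look as follows. The constant c and the particular choices f₁(x) = x + c, f₂(z) = z^2 and f₃(w) = abs2(w) are illustrative assumptions, not part of the proposal.

```julia
# One concrete instance of f = f₃ ∘ f₂ ∘ f₁ with m = r = s = n = 1.
# The constant c and the choices of f₁, f₂, f₃ are illustrative assumptions.
const c = 1.0 + 2.0im

f₁(x::Real) = x + c         # ℝ → ℂ: adding a complex number to a real input
f₂(z::Complex) = z^2        # ℂ → ℂ: a holomorphic intermediate step
f₃(w::Complex) = abs2(w)    # ℂ → ℝ: non-holomorphic, real-valued output

f(x::Real) = f₃(f₂(f₁(x)))  # ℝ → ℝ overall

@show typeof(f₁(0.5))   # ComplexF64: the intermediate is complex
@show f(0.5)            # 39.0625: the output is a Float64
```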

Following https://github.com/JuliaDiff/ChainRulesCore.jl/pull/167, the presence of complex intermediates in f is just an implementation detail. We could equally have defined f: ℝᵐ → ℝ²ʳ → ℝ²ˢ → ℝⁿ, identifying each complex number with a pair of reals, and the pushforwards and pullbacks of this new f should behave the same.
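A sketch of that identification for one illustrative choice (c = 1 + 2im, f₁(x) = x + c, f₂(z) = z^2, f₃(w) = abs2(w), all assumptions for illustration): the hypothetical helpers g₁, g₂ and g₃ carry each complex number as a tuple of two reals, and the composed function agrees with the complex version.

```julia
const c = 1.0 + 2.0im   # illustrative constant

# Complex-intermediate version (illustrative choice of f₁, f₂, f₃).
f(x::Real) = abs2((x + c)^2)

# The same map with every complex number stored as a pair of reals (ℂ ≅ ℝ²).
g₁(x::Real) = (x + real(c), imag(c))   # ℝ → ℝ²
g₂((a, b)) = (a^2 - b^2, 2 * a * b)    # ℝ² → ℝ²: real form of z ↦ z²
g₃((u, v)) = u^2 + v^2                 # ℝ² → ℝ: real form of abs2
g(x::Real) = g₃(g₂(g₁(x)))

for x in (-1.3, 0.0, 0.5, 2.0)
    @show f(x) ≈ g(x)   # true for every x
end
```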

Since tangents are, in general, derivatives of a primal with respect to a real, and cotangents are derivatives of a real with respect to a primal, the pushforward through f₁: ℝᵐ → ℂʳ should produce a complex tangent, while the pushforward through f₃: ℂˢ → ℝⁿ should produce a real tangent. Conversely, the pullback through f₃ should produce a complex cotangent, and the pullback through f₁ should produce a real cotangent.
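One way to see why the pullback through f₁ should return a real cotangent, sketched for an illustrative f₁(x) = x * c with a fixed complex c (an assumption): viewing ℂ as ℝ², a cotangent z̄ acts on a tangent ż through the real dot product real(conj(z̄) * ż). The pullback must reproduce that pairing for every real tangent ẋ, which forces x̄ = real(conj(c) * z̄), a real number. The helper names pushforward_f₁, pullback_f₁ and pairing are hypothetical, not ChainRulesCore API.

```julia
const c = 1.0 + 2.0im   # illustrative constant

# f₁(x) = x * c for real x (an illustrative choice of f₁).
pushforward_f₁(dx::Real) = dx * c                   # real tangent in, complex tangent out
pullback_f₁(zbar::Complex) = real(conj(c) * zbar)   # complex cotangent in, real cotangent out

# Viewing ℂ as ℝ², a cotangent acts on a tangent through the real dot product.
pairing(zbar::Complex, dz::Complex) = real(conj(zbar) * dz)

zbar = 0.3 - 0.7im      # an arbitrary cotangent arriving from f₂
xbar = pullback_f₁(zbar)
@show typeof(xbar)      # Float64
for dx in (1.0, -2.5, 0.1)
    # Defining property of a pullback: xbar * dx equals ⟨zbar, pushforward_f₁(dx)⟩.
    @show xbar * dx ≈ pairing(zbar, pushforward_f₁(dx))
end
```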

The pushforward case is fairly easy to handle. We can either 1) assume that a nonsensical tangent will never be passed and do nothing special (i.e. trust that upstream AD did the right thing), or 2) define custom frules that guarantee the tangent produced by unary functions f₃(::Complex)::Real is real.
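Option 2 can be sketched with hand-written pushforwards for an illustrative chain f₁(x) = x + c, f₂(z) = z^2, f₃(w) = abs2(w) (all assumptions). The tangent of abs2 is 2 * real(conj(w) * ẇ), which is real by construction, so the tangent leaving f₃ is a Float64 even though the intermediate tangents are complex. The helpers below are hypothetical stand-ins for frules, and the result is checked against a central finite difference.

```julia
const c = 1.0 + 2.0im   # illustrative constant

# Hypothetical frule-style pushforwards for f₁(x) = x + c, f₂(z) = z^2, f₃(w) = abs2(w).
pushforward_f₁(x, dx) = complex(dx)              # real tangent → complex tangent
pushforward_f₂(z, dz) = 2z * dz                  # holomorphic: multiply by f₂'(z)
pushforward_f₃(w, dw) = 2 * real(conj(w) * dw)   # complex tangent → real tangent

function pushforward_f(x::Real, dx::Real)
    z = x + c;     dz = pushforward_f₁(x, dx)
    w = z^2;       dw = pushforward_f₂(z, dz)
    t = abs2(w);   dt = pushforward_f₃(w, dw)
    return t, dt
end

x = 0.5
t, dt = pushforward_f(x, 1.0)
h = 1e-6
fd = (abs2((x + h + c)^2) - abs2((x - h + c)^2)) / (2h)   # central finite difference
@show dt                                 # 37.5, a Float64
@show isapprox(dt, fd; rtol = 1e-6)      # true
```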

The pullback case is more complicated. Currently, e.g. in Zygote, unless you create a complex number from reals by calling complex, you end up pulling complex cotangents back through the initial real part of your program. This is not only wasteful but could also break assumptions made by the rrules of upstream functions. For the binary functions f₁, I propose adding custom rrules for f₁(::Real, ::Complex)::Complex and f₁(::Complex, ::Real)::Complex that ensure the cotangent pulled back to a real primal is actually real.
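A sketch of the reverse pass for an illustrative chain (c = 1 + 2im, f₁(x) = x + c, f₂(z) = z^2, f₃(w) = abs2(w), all assumptions), contrasting a pullback for +(::Real, ::Complex) that treats x as if it were complex with one that projects onto the reals as proposed. The helper names are hypothetical stand-ins for rrules; the pullback of the holomorphic f₂ multiplies by conj(f₂'(z)), and since f(x) = ((x + 1)^2 + 4)^2, the expected gradient is f'(0.5) = 37.5.

```julia
const c = 1.0 + 2.0im   # illustrative constant

# Hypothetical rrule-style pullbacks.
pullback_abs2(w, tbar) = 2w * tbar         # f₃: real cotangent in, complex cotangent out
pullback_sq(z, wbar) = conj(2z) * wbar     # f₂: holomorphic, multiply by conj(f₂'(z))

# f₁ = +(x::Real, c::Complex): cotangents for (x, c).
pullback_add_embedded(zbar) = (zbar, zbar)          # x treated as if it were complex
pullback_add_projected(zbar) = (real(zbar), zbar)   # proposed: real primal, real cotangent

x = 0.5
z = x + c
w = z^2
zbar = pullback_sq(z, pullback_abs2(w, 1.0))

xbar_embedded, _ = pullback_add_embedded(zbar)
xbar_projected, _ = pullback_add_projected(zbar)
@show xbar_embedded    # 37.5 + 50.0im: the imaginary part is meaningless for a real x
@show xbar_projected   # 37.5 = f'(0.5)
```

Both versions agree on the real part; the projected version simply drops the imaginary part, which has no meaning for a real primal and would otherwise flow into upstream real-valued rrules.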

This came up as a point of discussion in JuliaDiff/ChainRules.jl#196, and I would appreciate feedback so we can clarify our conventions here.

sethaxen commented 4 years ago

See also this issue about why Zygote doesn't do this (tl;dr: Zygote basically treats all reals as embedded in the complex numbers): https://github.com/FluxML/Zygote.jl/issues/342, and an update here: https://github.com/FluxML/Zygote.jl/issues/472

Also cc'ing @MikeInnes, because this could change the behavior of Zygote.

ettersi commented 4 years ago

+1 for the pushforward / pullback of f(::Real)::Real with real sensitivities / adjoints to be real, i.e. for ChainRules to stay real if all its inputs are real. This convention allows complex pushforwards / pullbacks to be obtained using

```julia
invoke(frule, Tuple{typeof(Δx), typeof(f), complex(typeof(x))}, Δx, f, x)
invoke(rrule, Tuple{typeof(f), complex(typeof(x))}, f, x)
```

Conversely, if we defined pushforwards / pullbacks to be complex for some functions f(::Real)::Real, then there would be no way to get the real version of the derivatives.


Edit: Actually, this does not work, since Real is not a subtype of Complex (!(Real <: Complex)), so invoke cannot apply a Complex method to a real argument.
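The dispatch problem in the edit can be reproduced with a hypothetical function which_method that has separate Real and Complex methods (a sketch, not the frule/rrule case itself): invoke only accepts arguments that are instances of the requested signature, so requesting the Complex method with a Float64 argument throws, while the same request succeeds once the argument has been converted to a complex number.

```julia
# Hypothetical function with distinct Real and Complex methods.
which_method(x::Real) = "real method"
which_method(x::Complex) = "complex method"

x = 1.5

# invoke requires the arguments to be instances of the given signature.
# Float64 is not a subtype of Complex, so this call throws an error.
invoke_failed = try
    invoke(which_method, Tuple{Complex}, x)
    false
catch
    true
end

@show invoke_failed                                      # true
@show invoke(which_method, Tuple{Real}, x)               # "real method"
@show invoke(which_method, Tuple{Complex}, complex(x))   # "complex method", only after converting x
```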

Of course, a similar effect could be achieved using

```julia
frule(Δx, f, complex(x))
rrule(f, complex(x))
```

but this incurs some runtime penalty. In the Complex-vs-Real case this penalty would probably be acceptable in most circumstances, but I've been thinking that the same approach could also address analogous issues with AbstractArrays; see e.g. https://github.com/JuliaDiff/ChainRules.jl/issues/191 and https://github.com/JuliaDiff/ChainRules.jl/issues/52. The problem there is that it is not clear whether the adjoint of, e.g., f(A, B) = A*B with respect to A::Diagonal should be a Diagonal or a Matrix. If the above invoke worked, it would provide an interface for clarifying intent: rrule(*, A::Diagonal, B)[2](ΔC)[1] would be a Diagonal, and if you wanted a Matrix instead, you could invoke the Matrix method of the rrule. In this case, it would clearly not be acceptable to call rrule(*, Matrix(A::Diagonal), B), since materializing a dense Matrix from a Diagonal is far more costly than promoting a real to a complex number.
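A sketch of the Diagonal-versus-Matrix question using only the LinearAlgebra standard library (real matrices and random data assumed; this is not ChainRules' actual rule for *): for C = A * B the cotangent of A is ΔC * B', which is dense in general. Along any diagonal perturbation dA, the dense cotangent and its projection Diagonal(ΔC * B') give the same directional derivative, so both are valid on the subspace of diagonal matrices; they differ only in the type handed back to the caller.

```julia
using LinearAlgebra

n = 4
A = Diagonal(randn(n))
B = randn(n, n)
ΔC = randn(n, n)               # cotangent of C = A * B

ΔA_dense = ΔC * B'             # cotangent of A as if it were a general Matrix
ΔA_diag = Diagonal(ΔA_dense)   # the same cotangent projected onto diagonal matrices

dA = Diagonal(randn(n))        # an arbitrary diagonal perturbation of A
# Directional derivative of ⟨ΔC, A * B⟩ along dA is ⟨ΔC, dA * B⟩.
lhs = dot(ΔC, dA * B)
@show lhs ≈ dot(ΔA_dense, dA)  # true
@show lhs ≈ dot(ΔA_diag, dA)   # true: off-diagonal entries never pair with a diagonal dA
@show typeof(ΔA_dense), typeof(ΔA_diag)
```

Returning the projected Diagonal keeps the cotangent in the same structured type as the primal, while the dense Matrix is what a rule written for general matrices would naturally produce.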