Open sethaxen opened 4 years ago
See also this issue about why Zygote doesn't do this (tl;dr: Zygote basically treats all reals as embedded in the complex numbers): https://github.com/FluxML/Zygote.jl/issues/342 and an update here: https://github.com/FluxML/Zygote.jl/issues/472
Also ccing @MikeInnes because this could change the behavior of Zygote.
+1 for the pushforward / pullback of `f(::Real)::Real` with real sensitivities / adjoints to be real, i.e. for ChainRules to stay real if all its input is real. This convention allows complex pushforwards / pullbacks to be obtained using
```julia
invoke(frule, Tuple{typeof(Δx), typeof(f), complex(typeof(x))}, Δx, f, x)
invoke(rrule, Tuple{typeof(f), complex(typeof(x))}, f, x)
```
Conversely, if we defined pushforwards / pullbacks to be complex for some functions `f(::Real)::Real`, then there would be no way to get the real version of the derivatives.

Edit: Actually, this does not work since `!(Real <: Complex)` :unamused:
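To make the failure concrete, here is a minimal sketch using a hypothetical function `g` (a stand-in for `frule` / `rrule`, not part of ChainRules) with separate `Real` and `Complex` methods. It assumes that `invoke` requires every argument to be an instance of the corresponding type in the signature, so a real `x` cannot reach the `Complex` method without being converted first.

```julia
# `g` is a hypothetical stand-in for `frule` / `rrule`, used only for illustration.
g(x::Real) = :real_method
g(x::Complex) = :complex_method

x = 1.0

# Real is not a subtype of Complex, so a real argument never matches a Complex signature.
real_is_complex = Real <: Complex

# `invoke` with a signature the argument does not satisfy throws instead of dispatching.
invoke_failed = try
    invoke(g, Tuple{complex(typeof(x))}, x)
    false
catch
    true
end

# Converting the argument reaches the Complex method, at the cost of a runtime conversion.
converted = g(complex(x))
```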
Of course, a similar effect could be achieved using
```julia
frule(Δx, f, complex(x))
rrule(f, complex(x))
```
but this incurs some runtime penalty. In the case of `Complex` vs `Real`, this penalty would probably be acceptable in most circumstances, but I've been thinking that a similar approach could also be used for similar issues with `AbstractArray`s, see e.g. https://github.com/JuliaDiff/ChainRules.jl/issues/191 and https://github.com/JuliaDiff/ChainRules.jl/issues/52. The problem there is that it is not clear whether the adjoint of e.g. `f(A,B) = A*B` with respect to `A::Diagonal` should be a `Diagonal` or a `Matrix`. If the above `invoke` worked, then that would provide an interface for clarifying the intent: `rrule(*, A::Diagonal, B)[2](ΔC)[1]` would be a `Diagonal`, and if you wanted a `Matrix` instead then you could invoke the `Matrix` method of the `rrule`. And in this case, it would clearly not be acceptable to call `rrule(*, Matrix(A::Diagonal), B)`.
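The two candidate answers for the `Diagonal` case can be written out by hand. This is a sketch with made-up values for `A`, `B` and `ΔC`; it does not call `rrule`, and which of the two results `rrule` should return is exactly the open question.

```julia
using LinearAlgebra

# Made-up inputs for illustration only.
A = Diagonal([1.0, 2.0])
B = [1.0 2.0; 3.0 4.0]
ΔC = [1.0 0.0; 1.0 1.0]   # arbitrary cotangent of C = A * B

# Treating A as a general matrix, the cotangent is the dense matrix ΔC * B'.
ΔA_matrix = ΔC * B'

# Treating A as constrained to be diagonal, only the diagonal entries are free,
# so the cotangent is the dense result projected onto Diagonal.
ΔA_diagonal = Diagonal(diag(ΔA_matrix))
```

The off-diagonal entries of `ΔA_matrix` are nonzero here, so the two conventions give different results.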
Consider a case where we have a function `f: ℝᵐ → ℂʳ → ℂˢ → ℝⁿ = ℝᵐ → ℝⁿ`, which we can write as `f = f₃ ∘ f₂ ∘ f₁`. Typically `f₁` will produce a complex output by adding, subtracting, multiplying or dividing the real by a complex number or by calling `promote`, `complex`, `Complex` or `cis`. Typically `f₃` will produce a real output by calling a non-holomorphic function like `real`, `imag`, `abs`, `abs2`, `hypot`, or `angle`.

From https://github.com/JuliaDiff/ChainRulesCore.jl/pull/167, the fact that there are complex intermediates to `f` is just an implementation detail. We could have defined `f: ℝᵐ → ℝ²ʳ → ℝ²ˢ → ℝⁿ`, and the pushforwards and pullbacks of this new `f` should behave the same.

Since in general tangents are derivatives of a primal wrt a real, and co-tangents are derivatives of a real wrt a primal, the pushforward through `f₁: ℝᵐ → ℂʳ` should produce a complex tangent, while the pushforward through `f₃: ℂˢ → ℝⁿ` should produce a real tangent. Conversely, the pullback through `f₃` should produce a complex cotangent, and the pullback through `f₁` should produce a real cotangent.

The pushforward case is pretty easy to handle. We can 1) assume that a nonsensical tangent will not be passed and do nothing special (i.e. assume upstream AD did the right thing) or 2) define custom `frule`s that ensure that the produced tangent of unary functions `f₃(::Complex)::Real` is real.

The pullback case is more complicated. Right now e.g. in Zygote, unless you create a complex number from reals by calling `complex`, you'll end up pulling back complex numbers through the initial real part of your program, which not only is wasteful but could break assumptions of the `rrule`s of upstream functions. I propose for the binary functions `f₁` adding custom `rrule`s for `f₁(::Real, ::Complex)::Complex` and `f₁(::Complex, ::Real)::Complex` to ensure that the co-tangent pulled back to a real primal is actually real.

This came up as a point of discussion in JuliaDiff/ChainRules.jl#196, and I would appreciate feedback so we can clarify our conventions here.