FluxML / Zygote.jl

https://fluxml.ai/Zygote.jl/

Complex gradient on real function with complex intermediates #342

Closed sethaxen closed 3 years ago

sethaxen commented 5 years ago

I came across something odd while working with complex numbers:

julia> using Zygote
julia> Zygote.gradient(x->real(x + 2.0*im), 3.0)
(1.0,)
julia> Zygote.gradient(x->imag(x + 2.0*im), 3.0)
(0.0 + 1.0im,)

While the inputs and outputs of both functions are real, the first produces a real gradient, while the second produces a complex gradient. Is this Zygote's intended behavior?

Version info:

julia> versioninfo()
Julia Version 1.2.0
Commit c6da87ff4b (2019-08-20 00:03 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin18.6.0)
  CPU: Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-6.0.1 (ORCJIT, haswell)

(v1.2) pkg> st --manifest Zygote
    Status `~/.julia/environments/v1.2/Manifest.toml`
  [b552c78f] DiffRules v0.0.10
  [7a1cc6ca] FFTW v0.3.0
  [1a297f60] FillArrays v0.6.4
  [f6369f11] ForwardDiff v0.10.3
  [7869d1d1] IRTools v0.2.3
  [1914dd2f] MacroTools v0.5.1
  [872c559c] NNlib v0.6.0
  [77ba4419] NaNMath v0.3.2
  [ae029012] Requires v0.5.2
  [276daf66] SpecialFunctions v0.7.2
  [e88e6eb3] Zygote v0.3.4
  [700de1a5] ZygoteRules v0.1.0
  [b77e0a4c] InteractiveUtils 
  [37e2e46d] LinearAlgebra 
  [9a3f8284] Random 
  [10745b16] Statistics 
sethaxen commented 5 years ago

FWIW, Tracker and ForwardDiff do what I'd expect:

julia> using Tracker, ForwardDiff
julia> Tracker.gradient(x->imag(x + 2.0*im), 3.0)[1]
0.0 (tracked)
julia> ForwardDiff.derivative(x->imag(x + 2.0*im), 3.0)
0.0
sethaxen commented 5 years ago

For anyone coming across this, a workaround is to directly call complex:

julia> Zygote.gradient(x->imag(x + 2.0*im), 3.0)
(0.0 + 1.0im,)
julia> Zygote.gradient(x->imag(complex(x, 2.0)), 3.0)
(0.0,)
MikeInnes commented 5 years ago

Yes, this is expected. This is the mathematically correct result; otherwise the result would change if you asked for gradient(f, 3.0 + 0im). You can also call real after the gradient is calculated to avoid it, if needed.
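For the original example that would look like:

julia> g, = Zygote.gradient(x -> imag(x + 2.0*im), 3.0);

julia> real(g)
0.0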

mcabbott commented 5 years ago

Are we sure this is the desired behaviour? Promoting Int inputs to float gradients seems friendly, because by asking for a gradient you’re implicitly stating that x is continuous. But I’m less sure about promoting real to complex.

I was scribbling somewhere an example of adding scalars to Pauli matrices, in which it would clearly be crazy to return a matrix-valued gradient for a scalar x. (But I couldn’t get the factors of 2 straight...) Why treat ComplexF64 differently from 2x2 matrices which represent the same algebra?

sethaxen commented 5 years ago

> Yes, this is expected. This is the mathematically correct result; otherwise the result would change if you asked for gradient(f, 3.0 + 0im). You can also call real after the gradient is calculated to avoid it, if needed.

So do Tracker and ReverseDiff do the wrong thing here, or is there a fundamental difference between how these packages interpret complex sensitivities? I don't quite follow the above reasoning. While it may be that 3.0 and 3.0 + 0im are mathematically equivalent, they are different types, so I don't see why it's a problem if the results change, the former giving a real gradient and the latter giving a complex gradient.

If this remains the behavior, it would be nice to have this made explicit in the docs, with a recommendation to use something like Zygote.hook(real, x) or some official helper function within a codebase if imaginary parts of adjoints shouldn't be pulled back.
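For example, something like this seems to do the right thing in the toy case above (though I haven't tested it beyond that):

julia> Zygote.gradient(x -> imag(Zygote.hook(real, x) + 2.0*im), 3.0)
(0.0,)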

MikeInnes commented 5 years ago

> Promoting Int inputs to float gradients seems friendly, because by asking for a gradient you’re implicitly stating that x is continuous.

From Zygote's perspective, there is no fundamental difference between ints and floats; they both represent different (but equally finite and discrete) sets of points on the real line. If a gradient doesn't have a fractional part, it's legitimate to represent it with an Int (and Zygote does this in many cases). Conversely, if there is a fractional part, we don't assume it's OK to drop it based on the input type. The same is true for imaginary parts.
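For instance (roughly, depending on the exact rules involved):

julia> Zygote.gradient(x -> x^2, 3)    # the gradient 2x = 6 has no fractional part
(6,)

julia> Zygote.gradient(sqrt, 4)        # Int input, but the gradient 1/(2*sqrt(x)) does
(0.25,)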

> Why treat ComplexF64 differently from 2x2 matrices which represent the same algebra?

We have to have a default representation of complex numbers to use in the standard library, and that default happens to be ComplexF64 (the same for ints, floats etc). But beyond that, there shouldn't be any differences: the same mathematical function will always return the same gradient information at a given point, regardless of how any mathematical objects are represented at runtime.

This is not the only valid behaviour, but it is actually the one with the least special cases as far as implementing Zygote goes.

MikeInnes commented 5 years ago

> So do Tracker and ReverseDiff do the wrong thing here, or is there a fundamental difference between how these packages interpret complex sensitivities?

Tracker and ReverseDiff are both self-consistent, insofar as they don't really support complex AD; you're actually differentiating a slightly different real->real function. Another way to look at this is that, to the extent an F64 is a "sparse"/efficient representation of a ComplexF64, the imaginary component is "fixed" to zero rather than just coincidentally being zero (see #163 for more discussion).
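To spell the two readings out on the original example: restricted to the real axis, t -> imag(t + 2im) is the constant function t -> 2, so its derivative there is 0. Viewed on C, f(z) = imag(z) + 2 with z = x + y*im has df/dx = 0 and df/dy = 1, which (in the convention Zygote uses for real-valued outputs, roughly gradient = df/dx + im*df/dy) gets packaged as the single number 0.0 + 1.0im.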

Having Zygote's semantics change based on input type could definitely cause problems. For hand-written examples it's not a big deal, but in more complex code you might dynamically choose the type of an object without realising that this changes how your AD works down the line. That may or may not be an acceptable tradeoff for simplicity in some other use cases.

mcabbott commented 5 years ago

> If a gradient doesn't have a fractional part, it's legitimate to represent it with an Int (and Zygote does this in many cases). Conversely, if there is a fractional part, we don't assume it's OK to drop it based on the input type. The same is true for imaginary parts.

I agree completely about Int/float, but am not sure that imaginary is the same. Here's my example, with 3 representations of an R^2 -> R function that I think ought to be equivalent:

h1(x::Real, y::Real) = x^2 + y^2
h2(x::Real, y::Real) = real( (x+im)^2 + (y+im)^2 ) + 2
h3(x::Real, y::Real) = jreal( (o*x+j)*(o*x+j) + (y*o+j)*(y*o+j) ) + 2  # jreal, o, j are defined below

h1(0,0) == 0
all( h1(x,y) ≈ h2(x,y) ≈ h3(x,y) for x in randn(10), y in randn(10) )

where h3 uses a matrix representation of the complex numbers, internally: (o,j) represent the same algebra as (1,im):

j = [0 1; -1 0]; o = [1 0; 0 1];

j * j == -o
j * o == o * j == j
o * o == o

using LinearAlgebra: tr

jmag(x::Matrix) = tr(-x * j)/2   # reads off the j-coefficient
jreal(x::Matrix) = tr(x)/2       # reads off the o-coefficient

Now we can calculate gradients, and since h is the height of a nice parabolic bowl, I think we expect zero gradient at the bottom, x,y = 0,0:

Zygote.gradient(h1, 0,0) == (0, 0)
Zygote.gradient(h2, 0,0) == (0 - 2im, 0 - 2im)
Zygote.gradient(h3, 0,0) == (0, 0)

ForwardDiff.gradient(v -> h1(v...), [0,0]) == [0,0]
# ForwardDiff.gradient(v -> h2(v...), [0,0])  # is an error?
ForwardDiff.gradient(v -> h3(v...), [0,0]) == [0,0]

This imaginary direction is not wrong, in the sense that if we walk that way, then we can indeed tunnel out of the bottom of this parabola! And of course this is true in either representation:

h1_(x, y) = x^2 + y^2                                       # as above, but without the ::Real restriction
h2_(x, y) = real( (x+im)^2 + (y+im)^2 ) + 2
h3_(x, y) = jreal( (o*x+j)*(o*x+j) + (y*o+j)*(y*o+j) ) + 2

h2_(0 + 0.1im, 0 + 0.1im) ≈ -0.42
h3_(0*o + 0.1j, 0*o + 0.1j) ≈ -0.42

But that's not the question we asked. I think it would be reasonable to insist that you pass input x,y = 0+0im, 0+0im if you did in fact want to ask which of 4 directions to walk in, not 2.
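(And if one does explicitly ask the 4-direction question, by passing complex inputs to the unrestricted h2_, the complex answer is of course fine by me:

Zygote.gradient(h2_, 0.0 + 0im, 0.0 + 0im) == (0 - 2im, 0 - 2im)

same bowl, but now the question includes the imaginary directions.)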

And further, to argue that (0 - 2im, 0 - 2im) is the correct answer here, I think one would also need to explain why (-2j, -2j) isn't also mathematically correct. Why should the representation matter?

> This is not the only valid behaviour, but it is actually the one with the least special cases as far as implementing Zygote goes.

In this toy example, taking the real part of the gradients right before gradient returns them would be fine (and simple). Is it obvious whether doing that is always equivalent to keeping track of which quantities are real the whole way through the calculation?

MikeInnes commented 5 years ago

The difference between j and im here is that from Zygote's perspective, im is a complex number, whereas j is just a bag of real numbers. They are different because im isa Number, which means that all the usual mathematical adjoints are defined on it, whereas j just gets the default structural fallbacks. If you implemented j as a custom number type you'd get the right behaviour.

It's annoying to have that semantic subtlety, but it's also pretty fundamental; the same would be true for a custom float type implemented by bit operations. It would be completely non-differentiable until it's declared as a number. We will always have to choose which set of types gets natural / mathematical derivatives vs. structural / element-wise ones.

Calling real.(...) on the gradient really only works in the simplest cases, when the gradient is a plain array. To handle this more generally we really need to handle it at the AD semantics level. The only sound way to do this is to make <:Real natural and <:Number structural, which is quite a large change and would require a new (largely redundant) complex adjoint library. It's not ideal, though if someone wants to try it on a branch, I'm open to that.
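For instance, the gradient of a user-defined struct comes back as a NamedTuple (Point here is just a stand-in), where broadcasting doesn't even apply, so you'd need a recursive walk over arbitrarily nested structure instead:

julia> struct Point; x::Float64; y::Float64; end

julia> g, = Zygote.gradient(p -> imag(p.x + p.y*im), Point(3.0, 2.0));

julia> g
(x = 0.0 + 1.0im, y = 1.0 + 0.0im)

julia> real.(g)
ERROR: ArgumentError: broadcasting over dictionaries and `NamedTuple`s is reserved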

I take the point about complex numbers being higher-dimensional than real ones; this makes the issue closer to the sparse array one than to int/float. I think dimensionality may be a useful (if imperfect) heuristic for deciding what should be natural vs. structural, but the implementation difficulty remains.

mcabbott commented 5 years ago

Thanks for the sparse link BTW, lots of interesting issues there.

I guess my verbose example is trying to argue that problems involving gradients of real-valued loss functions are intrinsically over R, and it's in that world that suddenly getting a complex result is no less surprising than jumping to matrices / quaternions / etc would be. Whereas it sounds like Zygote builds in an assumption that all numbers are really in C, only some don't know it yet. I can see how this would be practical.

Will think some more. Would be interested to know of a pathological example where calling real.(...) on the final gradient doesn't work (i.e. where the result differs from keeping all real inputs' sensitivities real at every step). And will see if I can figure out how difficult implementing a constraint that real inputs have real sensitivities might be.

mcabbott commented 5 years ago

OK maybe this is simpler than expected: see what you think of keeplike in https://github.com/mcabbott/ZygoteRules.jl/blob/real/src/adjoint.jl#L49

That change to @adjoint doesn’t actually affect #356. However in the same spirit I think the gradient for real ought never to produce a complex sensitivity, so I wrote that too.
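The idea is roughly this projection, applied to each sensitivity as the pullback returns it (a simplification, not the literal code on the branch):

keeplike(x::Real, dx::Complex) = real(dx)   # a real primal never receives a complex sensitivity
keeplike(x, dx) = dx                        # everything else passes through unchanged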

MikeInnes commented 5 years ago

It's not ideal to capture all of the original arguments in the pullback that way (memory usage is an issue). Also, this won't work for things like Any[1, 2, 3] unless it recurses, which gets increasingly expensive; nor does it handle user-defined structs etc.
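Concretely, the memory concern looks something like this (a sketch, not what Zygote actually generates):

function pullback_with_keep(f, x)
    y, back = Zygote.pullback(f, x)
    # `x` is captured by the closure below, so it stays alive for the
    # entire backward pass just so the sensitivity can be projected:
    y, dy -> keeplike(x, first(back(dy)))
end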