jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

vjp of dtype promotion silently truncates complex numbers #3402

Closed shoyer closed 4 years ago

shoyer commented 4 years ago

As noticed in https://github.com/google/jax/pull/3398:

In [9]: from jax import vjp

In [10]: out, f_vjp = vjp(lambda x: 1j * x, 1.0)

In [11]: f_vjp(1 + 0j)  # wrong!
Out[11]: (DeviceArray(0., dtype=float32),)

In [12]: out, f_vjp = vjp(lambda x: 1j * x, 1.0 + 0j)

In [13]: f_vjp(1 + 0j)  # correct
Out[13]: (DeviceArray(0.+1.j, dtype=complex64),)

Instead, it looks like the cotangent always gets cast to the input dtype.

The result from the first calculation of f_vjp should either be the complex number 1j (preferred) or at the very least should raise an error. Right now it calculates the wrong result.

mattjj commented 4 years ago

This might be working as intended: the cast from f32 to c64 is like the R->C function x -> x + 0i, which if we write as an R->R^2 function has the Jacobian matrix [[1] [0]], the transpose of which is [1 0] (corresponding to the C->R function which drops the imaginary part of a c64 value and produces an f32). Moreover we'd expect the transpose of an a --o b function to be b --o a, where here we should take a to be f32 and b to be c64.
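
To make the Jacobian argument concrete, here is a plain-Python sketch (stdlib only, no JAX). It identifies a complex number a + bi with the real pair (a, b), writes the real-to-complex cast as the 2x1 matrix [[1], [0]], and transposes it. The helper names are hypothetical; this checks only the linear algebra described above, not JAX's implementation.

```python
# Identify a complex number a + bi with the real column vector (a, b).
# The f32 -> c64 cast x -> x + 0i is then an R -> R^2 linear map.
cast_jacobian = [[1.0],   # d(real part)/dx
                 [0.0]]   # d(imag part)/dx


def transpose(matrix):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*matrix)]


def apply(matrix, vector):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


cast_transpose = transpose(cast_jacobian)  # [[1.0, 0.0]]: an R^2 -> R map


def cotangent_to_real(z):
    """Apply the transposed cast to a complex cotangent z = a + bi."""
    (result,) = apply(cast_transpose, [z.real, z.imag])
    return result


print(cast_transpose)              # [[1.0, 0.0]]
print(cotangent_to_real(1 + 0j))   # 1.0
print(cotangent_to_real(0 + 1j))   # 0.0: the imaginary part is dropped
```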

Perhaps it looks surprising in this case because someone might think of 1.0 and 1.+0j as "the same", even though they have different types? (And since this is Python we don't annotate functions with types and instead rely on specializing the function according to the value to which it's applied, in this case an f32 literal.)

shoyer commented 4 years ago

I'm not sure it makes sense to define "linear" functions between the real and complex numbers, like R -> C or C -> R. For example, Wikipedia suggests that the input and output vector spaces of a linear map need to be over the same field. (In some cases we might be interested in linear functions over modules, but either way the input and output are modules over the same ring.)

This model would suggest that every linear function transposed by VJP is either a function R^n -> R^m or a function C^n -> C^m. In which case, dtype promotion would be a convenient shortcut for evaluating these functions, rather than an intrinsic part of their nature.

Intuitively, I think of f_vjp as describing how to perform a calculation based on the (already linear) function lambda x: 1j * x, with the initial value 1.0 (or whatever) just a matter of book-keeping used in the forward calculation and required for partial evaluation / transposing. I do agree that this would be less surprising if Python functions included argument types as part of their definition.

shoyer commented 4 years ago

Well, regardless of whether it is valid mathematically or not, JAX certainly works this way for exactly the reason you describe:

def _convert_element_type_transpose_rule(t, *, new_dtype, old_dtype):
  assert t.dtype == new_dtype, (t.dtype, new_dtype)
  return [convert_element_type_p.bind(t, new_dtype=old_dtype,
                                      old_dtype=new_dtype)]

shoyer commented 4 years ago

I realize now that there is at least one canonical "linear" function from R -> C: the real-valued Fourier transform np.fft.rfft. (Coincidentally, I tried to reimplement its transpose in #3398.)

Our current transpose for rfft (which matches TensorFlow's) is only correct if the imaginary part of the answer is silently truncated, i.e., it would be correct op-by-op if rfft(x) were written as fft(real_to_complex(x)).

This feels quite wrong, but I guess for these particular examples (real_to_complex and rfft) it works, because the transpose is still linear in the full complex vector space.
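
The op-by-op reading of rfft as fft(real_to_complex(x)) can be checked numerically. The sketch below (stdlib only, a naive O(n^2) DFT, not JAX's implementation) writes rfft as "cast to complex, full DFT, keep the first n//2 + 1 bins" and builds the transpose by reversing those steps: zero-pad the cotangent, apply the DFT (its matrix is symmetric), then keep only the real part. It assumes the pairing between a complex cotangent u and a complex output y is Re(sum(u_k * y_k)), which is consistent with the f_vjp outputs in this thread; the function names are illustrative.

```python
import cmath
import random


def dft(z):
    """Naive discrete Fourier transform (same sign convention as numpy.fft.fft)."""
    n = len(z)
    return [sum(z[j] * cmath.exp(-2j * cmath.pi * j * k / n) for j in range(n))
            for k in range(n)]


def rfft(x):
    """rfft written op by op: real -> complex cast, full DFT, keep n//2 + 1 bins."""
    n = len(x)
    return dft([complex(v) for v in x])[: n // 2 + 1]


def rfft_transpose(u, n):
    """Transpose of rfft, reversing the ops: zero-pad, DFT, then take the real part.

    The DFT matrix is symmetric, so under a bilinear pairing it is its own
    transpose; the final .real is the transpose of the real -> complex cast.
    """
    padded = list(u) + [0j] * (n - len(u))
    return [v.real for v in dft(padded)]


def pairing(u, y):
    """Bilinear pairing Re(sum(u_k * y_k)) between a cotangent and an output."""
    return sum(a * b for a, b in zip(u, y)).real


random.seed(0)
n = 8
x = [random.uniform(-1, 1) for _ in range(n)]
u = [complex(random.uniform(-1, 1), random.uniform(-1, 1)) for _ in range(n // 2 + 1)]

# Adjoint identity: <u, rfft(x)> should equal <rfft_transpose(u), x>.
lhs = pairing(u, rfft(x))
rhs = sum(a * b for a, b in zip(rfft_transpose(u, n), x))
print(abs(lhs - rhs) < 1e-9)  # True
```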

On the other hand, the transpose of lambda x: 1j * x breaks linearity when composed with other functions, e.g., continuing my first example:

In [15]: f_vjp(1j * (1 + 0j))[0]
Out[15]: DeviceArray(-1., dtype=float64)

In [16]: 1j * f_vjp(1 + 0j)[0]
Out[16]: DeviceArray(0.+0.j, dtype=complex128)

f_vjp is only linear over the real numbers. Due to the way JAX's linear functions are composed, it seems we can only count on mathematical linearity over the real numbers for any functions that involve any real -> complex dtype promotion.
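
A plain-Python model reproduces the outputs above. It assumes (consistent with the transcripts in this thread, not taken from JAX's source) that the VJP of lambda x: 1j * x at a real point multiplies the cotangent by 1j and then applies the transpose of the real-to-complex cast, which keeps only the real part. The name model_f_vjp is hypothetical.

```python
def model_f_vjp(u: complex) -> float:
    """Hypothetical model of f_vjp for f = lambda x: 1j * x at a real primal.

    Transpose of "multiply by 1j" (multiply the cotangent by 1j), followed by
    the transpose of the real -> complex cast (keep the real part).
    """
    return (1j * u).real


u = 1 + 0j
print(model_f_vjp(u))        # 0.0, as in f_vjp(1 + 0j)
print(model_f_vjp(1j * u))   # -1.0, as in f_vjp(1j * (1 + 0j))
print(1j * model_f_vjp(u))   # 0j, as in 1j * f_vjp(1 + 0j)

# R-linearity holds: scaling the cotangent by a real number commutes with f_vjp.
v = 0.5 + 2j
assert model_f_vjp(2.5 * v) == 2.5 * model_f_vjp(v)
# C-linearity fails: scaling by 1j does not commute with f_vjp.
assert model_f_vjp(1j * u) != 1j * model_f_vjp(u)
```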

I'm not sure if this is actually a problem or not, but given that these transpose definitions seem to be useful (at least for gradient-based optimization), perhaps it isn't worth worrying about.

mattjj commented 4 years ago

See #610 for a comment on our convention for R->C and C->R functions, as well as relevant sections of the autodiff cookbook, the Autograd tutorial, and Ch 4 of Dougal's thesis. The convention works out really nicely, and was motivated in large part by thinking about canonical functions like rfft. (Tangentially, I still need to look into #3293 and whether it's a bug in the rule for conj...)

mattjj commented 4 years ago

There was a bug with the autodiff cookbook explanation! Someone found it in #3433, and I rewrote that section to be clearer (and correct-er) in #3434. I think the explanation in there now might help resolve this issue as well.

The Wikipedia page you linked includes this bit:

Occasionally, V and W can be vector spaces over different fields. It is then necessary to specify which of these ground fields is being used in the definition of "linear". [...] For example, the conjugation of complex numbers is an R-linear map C -> C, but it is not C-linear, where R and C are symbols representing the sets of real numbers and complex numbers, respectively.

(This might rely on having a field homomorphism, which we have from R to C.)

I think our convention is basically to follow R-linearity when we're not in the holomorphic case. The cookbook now explains how to define Jacobians for (potentially non-holomorphic) C->C functions, essentially by identifying them with R^2->R^2 functions. Similarly, for R->C and C->R functions we identify them with R->R^2 and R^2->R functions, respectively, and use R-linearity. These conventions are nice because they specialize to the right thing in the holomorphic case, and they allow someone to evaluate whatever Jacobians they want in terms of these JVPs and VJPs unambiguously (e.g. by vmapping a JVP or VJP over a standard basis of size 2n or 2m to get the Jacobian matrix of a C^n->C^m function).
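
Following the identification of C with R^2 described above, evaluating an R-linear map on the standard basis {1, 1j} recovers its 2x2 real Jacobian (the same idea as vmapping a JVP over a basis, shown here in plain Python without JAX). It uses the standard fact that such a matrix represents a C-linear map exactly when it commutes with the matrix of multiplication by i, which separates multiplication by 1j from conjugation. The helper names are hypothetical.

```python
def real_jacobian(f):
    """2x2 real matrix of an R-linear map f: C -> C in the basis {1, 1j}.

    Column k is f applied to the k-th basis vector, written as (real, imag).
    Because f is assumed R-linear, evaluating on the basis gives the exact Jacobian.
    """
    cols = [f(1 + 0j), f(0 + 1j)]
    return [[cols[0].real, cols[1].real],
            [cols[0].imag, cols[1].imag]]


def matmul(a, b):
    """Product of two 2x2 matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(2)) for j in range(2)]
            for i in range(2)]


MUL_I = real_jacobian(lambda z: 1j * z)          # [[0, -1], [1, 0]]
conjugate = real_jacobian(lambda z: z.conjugate())  # [[1, 0], [0, -1]]


def is_c_linear(m):
    """An R-linear map on C is C-linear iff its matrix commutes with multiplication by i."""
    return matmul(m, MUL_I) == matmul(MUL_I, m)


print(is_c_linear(MUL_I))      # True
print(is_c_linear(conjugate))  # False: R-linear but not C-linear
```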

WDYT?

shoyer commented 4 years ago

One thing that really blows my mind is that with #3398 we can transpose linear functions over the ring of integers Z as well. If we have a function from Z -> R, then the transpose assumes something like "Z-linearity" and results in a truncating map R -> Z, e.g.,

In [11]: jax.linear_transpose(lambda x: jnp.pi * x, 1)(100.0)
Out[11]: (DeviceArray(314, dtype=int64),)

Given how transposing fits into JAX's existing autodiff functionality, I don't really see how we could do this any other way, but I'm also not at all sure this makes sense.

There does seem to be a formal meaning to [transposing a linear map between different modules], and at least on the face of it this fails that test:

In [29]: f = lambda x: jnp.pi * x

In [30]: Tf = lambda x: jax.linear_transpose(f, 1)(x)[0]

In [31]: 2.5 * f(10)
Out[31]: 78.53981633974483

In [32]: Tf(2.5) * 10
Out[32]: DeviceArray(70, dtype=int64)

That said, I'm totally out of my mathematical depth here. I'll see if I can interest one of my colleagues who is an algebraist...
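
The two integer outputs above can be reproduced with a plain-Python model. It assumes (labelled here as an assumption, not read from JAX's source) that the transpose multiplies by pi and then casts back to the integer input dtype, truncating toward zero like int(). The name model_Tf is hypothetical.

```python
import math


def f(x):
    """Multiply by pi; the primal 1 in the thread is an integer."""
    return math.pi * x


def model_Tf(y: float) -> int:
    """Hypothetical model of jax.linear_transpose(f, 1) applied to y.

    Assumption: multiply by pi, then cast back to the integer input dtype,
    which truncates toward zero like int().
    """
    return int(math.pi * y)


print(model_Tf(100.0))     # 314
print(2.5 * f(10))         # about 78.54
print(model_Tf(2.5) * 10)  # 70, not equal to 2.5 * f(10)
```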

geraschenko commented 4 years ago

I think Matt's comment pointing to #610 is right on the money. The root issue is confusion about the meaning of the covector 1 + 0j in the expression f_vjp(1 + 0j). There are two reasonable(-ish) interpretations, which answer different questions. We can call them the "holomorphic=True" and "holomorphic=False" interpretations.

Background

Backing up briefly, the transpose of an R-linear map f: V --> W is not a linear map W --> V, but a linear map between the dual spaces, Hom_R(W, R) --> Hom_R(V, R), where "Hom_R(X, Y)" means "the space of R-linear maps from X to Y". The transpose f^T is simply composition with f, i.e. if g: W --> R is a linear functional on W, then f^T(g): V --> R is the linear functional which sends a vector v to g(f(v)). Note that this actually makes sense even if f is not linear, or even if V is not a vector space over R.

If we choose a basis {v_1, ..., v_n} for V (and if V is finite-dimensional), then we can identify V with Hom_R(V, R) by thinking of v_i as the linear functional which is 1 on v_i and 0 on v_j for j != i, and extending this to a linear isomorphism V --> Hom_R(V, R). This is what we're usually doing when we talk about transposes of matrices, but clear communication is only possible if everybody agrees on the basis. Much confusion arises when there are multiple implicit bases.

All this applies to vector spaces over C if you just replace every instance of R with C. It even applies to modules over Z or any other commutative ring.
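
The definition above (the transpose as precomposition with f, plus a basis to turn functionals back into vectors) can be sketched in plain Python. The matrix A and covector w are arbitrary illustrative values; the point is that reading f^T(g) off the standard basis gives the same vector as the matrix transpose A^T applied to w.

```python
# A linear map f: R^2 -> R^3 given by a 3x2 matrix A (list of rows).
A = [[1.0, 2.0],
     [0.0, -1.0],
     [3.0, 4.0]]


def f(v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]


def transpose_as_precomposition(g):
    """f^T(g) is the functional v -> g(f(v)); no matrix transpose is needed."""
    return lambda v: g(f(v))


def functional_from_vector(w):
    """Identify a vector w with the functional x -> sum(w_i * x_i) (standard basis)."""
    return lambda x: sum(wi * xi for wi, xi in zip(w, x))


w = [1.0, -2.0, 0.5]  # a covector on R^3, via the standard basis
fT_g = transpose_as_precomposition(functional_from_vector(w))

# Read f^T(g) back as a vector by evaluating it on the standard basis of R^2.
fT_w = [fT_g([1.0, 0.0]), fT_g([0.0, 1.0])]

# It agrees with multiplying w by the matrix transpose A^T.
A_T_w = [sum(A[i][j] * w[i] for i in range(3)) for j in range(2)]
print(fT_w, A_T_w)  # both [2.5, 6.0]
```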

holomorphic=True vs holomorphic=False

Ok, back to the question of what 1 + 0j means when we write f_vjp(1 + 0j). It has to represent a linear functional on C, but are we talking about an R-linear functional or a C-linear functional?

The holomorphic=True interpretation is that 1 + 0j is the C-linear functional given by multiplication by 1 + 0j. In this case f_vjp(1 + 0j) is given by composing "multiply by 1 + 0j" with "multiply by 1j", which is "multiply by 1j". So f_vjp(1 + 0j) = 1j, as Stephan first suggested.

The holomorphic=False interpretation is that 1 + 0j is an R-linear functional on C. To identify it with an R-linear functional, we have to choose a basis for C. The usual basis is {1, 1j}, though algebraically there's nothing special about this basis. Using this basis, 1 + 0j corresponds to the R-linear map which takes the real part of a complex number. In this case f_vjp(1 + 0j) is given by composing "take the real part" with "multiply by 1j", which is "return 0" on R. So f_vjp(1 + 0j) = 0, as Matt said in his "working as intended" comment.
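
Both interpretations can be evaluated directly for f = "multiply by 1j" and the covector 1 + 0j. This plain-Python sketch follows the two readings above literally; the variable names are illustrative.

```python
def f(x):
    """The map from the thread: multiply by 1j."""
    return 1j * x


u = 1 + 0j

# holomorphic=True: u is the C-linear functional w -> u * w.
# Composing with f gives "multiply by u * 1j", identified with the number u * 1j.
holomorphic_true = u * f(1)  # 1j


# holomorphic=False: with the basis {1, 1j}, u = 1 + 0j is the R-linear
# functional "take the real part". Composing with f on real inputs gives
# x -> Re(1j * x), which is the zero functional on R.
def take_real_part(w):
    return w.real


def composed(x):
    return take_real_part(f(x))


holomorphic_false = composed(1.0)  # 0.0: the functional vanishes on the basis {1.0}

print(holomorphic_true, holomorphic_false)  # 1j 0.0
```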

What about Z --> R?

What's transpose(lambda x: pi * x, 1)(100.0)? Again, there are two reasonable interpretations for what 100.0 means.

The "holomorphic=True" interpretation is that it's the R-linear functional "multiply by 100.0". Composing with "multiply by pi", we get "multiply by 100*pi". Note that this is not an element of Z. It's a map from Z to R. We would only be able to identify this with an element of Z if we had an R-basis for Z, which we don't.

The "holomorphic=False" interpretation is that 100.0 is a Z-linear function from R to Z. Good luck deciding which one, because there's no conventionally agreed-upon Z-basis for R. If you decide that the expression 100.0 corresponds to some Z-linear functional g: R --> Z, then the final answer is the Z-linear functional on Z which sends x to g(pi x). This is a Z-linear map from Z to Z. Since Z has a unique choice of Z-basis (up to sign), this map `x --> g(pi x)` must indeed correspond to an integer, namely g(pi).

Btw, it turns out that g(pi) is 0, because there are no non-zero Z-linear functionals from R to Z. If g(x) = n for some real number x and some non-zero integer n, then there exists some integer k which does not divide n. By Z-linearity we have k*g(x/k) = n, where g(x/k) must be an integer, contradicting the assumption that k does not divide n. The basic issue here is that the Z-module R doesn't even have a Z-basis (i.e. any spanning set has linear dependencies; this doesn't come up when you work over a field). You can get a non-trivial "holomorphic=False" example by considering multiplication by pi as a linear map from Q to R. R does have a basis as a vector space over Q, but again there's no conventionally agreed-upon choice of basis.
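
The divisibility step can be made explicit by picking a concrete k that cannot divide n; k = |n| + 1 is one such choice (any integer k with |k| > |n| works).

```latex
\begin{align*}
&\text{Suppose } g(x) = n \neq 0 \text{ and let } k = |n| + 1. \\
&\text{By } \mathbb{Z}\text{-linearity, } k\, g\!\left(\tfrac{x}{k}\right) = g\!\left(k \cdot \tfrac{x}{k}\right) = g(x) = n,
  \quad \text{with } g\!\left(\tfrac{x}{k}\right) \in \mathbb{Z}, \\
&\text{so } k \mid n. \text{ But } 0 < |n| < k, \text{ so } k \nmid n, \text{ a contradiction. Hence } g \equiv 0.
\end{align*}
```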

NeilGirdhar commented 4 years ago

I've been watching this issue, and I agree with Stephan that this "feels wrong" and would easily cost me many days of debugging if I ran into it. But I also agree with Matt that it makes mathematical sense. After reading Anton's comment, I wanted to ask whether it would be possible to force users to cast explicitly between fields?

In normal numpy code, widening from float32 to float64 is safe, just like widening ("changing field") from float64 to complex128 is safe. But here, it seems that there's some ambiguity when changing field with what to do with the cotangent. If the field-change were done explicitly, that operator could be told what to do in the reverse mode.

For example, out, f_vjp = vjp(lambda x: 1j * x, 1.0) would be an error (cannot cast real to complex implicitly). Instead, you would write:

out, f_vjp = vjp(lambda x: 1j * real_to_complex(x, holomorphic=True), 1.0)

I'm not sure of the interface because I don't understand the assumptions.

Wouldn't forcing the explicit cast prevent the very confusing bug that Stephan mentioned?
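
The proposal can be sketched in plain Python to show what the explicit flag would control. Everything here is hypothetical: real_to_complex(x, holomorphic=...) is not a JAX API, and the sketch only illustrates the two transpose choices discussed in this thread for x -> 1j * x.

```python
def real_to_complex(x: float, holomorphic: bool):
    """HYPOTHETICAL explicit cast (not a JAX API) that also returns its transpose.

    holomorphic=False: the transpose keeps only the real part of the cotangent
    (the R-linear convention used today for implicit promotion).
    holomorphic=True: the transpose passes the complex cotangent through unchanged.
    """
    def transpose(ct: complex):
        return ct if holomorphic else ct.real
    return complex(x), transpose


def vjp_times_1j(x: float, holomorphic: bool):
    """VJP of x -> 1j * real_to_complex(x) under the chosen convention."""
    z, cast_T = real_to_complex(x, holomorphic)
    out = 1j * z
    # Transpose of "multiply by 1j" multiplies the cotangent by 1j;
    # then the cast's transpose is applied.
    return out, lambda ct: cast_T(1j * ct)


_, f_vjp_true = vjp_times_1j(1.0, holomorphic=True)
_, f_vjp_false = vjp_times_1j(1.0, holomorphic=False)
print(f_vjp_true(1 + 0j))   # 1j
print(f_vjp_false(1 + 0j))  # 0.0
```

With the flag spelled out at the cast, the choice between the two answers is visible in the user's code rather than implied by the dtype of a literal.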

mattjj commented 4 years ago

@geraschenko Thanks so much for unpacking that! I'm still grokking the latter points, but I wanted to say thanks so much for taking the time to explain things so clearly. It's nice to have a real mathematician weigh in and level-up our knowledge about this stuff.

geraschenko commented 4 years ago

@mattjj Thanks; it's my pleasure. Stephan has an increasingly refined skill for nerd sniping me with this kind of thing :-).

@NeilGirdhar I agree with your intuition that something in Stephan's original example feels wrong. To my eye, the holomorphic=True interpretation is the more reasonable one. It's awfully strange to denote the function "take the real part" by 1.0 + 0.0j.

Even for non-holomorphic (but still real-differentiable) functions, if you're using complex dtypes for your inputs and outputs, that suggests to me that you intend for the complex structure to be meaningful.

The following decomposition may be useful. For a complex covector u, f_vjp(u) breaks up into a holomorphic and an antiholomorphic part: it takes a complex tangent vector v to f_vjp_holomorphic(u).dot(v) + f_vjp_antiholomorphic(u).dot(jnp.conj(v)). I think current JAX behavior defines f_vjp(u) as f_vjp_holomorphic(u) + jnp.conj(f_vjp_antiholomorphic(u)).

An instructive example may be lambda z: z * jnp.conj(z). Taking absolute squares is the sort of thing you'd do to convert a wave function into a probability distribution, so maybe this kind of thing comes up in real use cases (?). What would you expect the vjp of this function to be? Here's what we have right now:

import jax
import jax.numpy as jnp

f = lambda z: z * jnp.conj(z)

out, tangent_out = jax.jvp(f, (3.0 + 4.0j,), (1.0 + 2.0j,))
print(tangent_out)
# (22+0j)
# Correct. I expect
# tangent_in * jnp.conj(primal_in) + jnp.conj(tangent_in) * primal_in

out, f_vjp = jax.vjp(f, 3.0 + 4.0j)
print(*f_vjp(1.0 + 2.0j))
# (6-8j)
# Correct? I expect
# holomorphic part: jnp.conj(primal_in) * cotangent_in    (3-4j) * (1+2j) = 11+2j
# anti-holomorphic part: primal_in * cotangent_in         (3+4j) * (1+2j) = -5+10j
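
The arithmetic in those comments can be checked with plain Python complex numbers (no JAX). This evaluates the proposed decomposition f_vjp(u) = f_vjp_holomorphic(u) + conj(f_vjp_antiholomorphic(u)) for f(z) = z * conj(z); the variable names are illustrative.

```python
z = 3.0 + 4.0j  # primal_in
v = 1.0 + 2.0j  # tangent_in for the JVP
u = 1.0 + 2.0j  # cotangent_in for the VJP

# JVP of f(z) = z * conj(z): v * conj(z) + conj(v) * z
tangent_out = v * z.conjugate() + v.conjugate() * z

# VJP pieces from the decomposition
holomorphic_part = z.conjugate() * u      # coefficient of v
antiholomorphic_part = z * u              # coefficient of conj(v)
f_vjp_u = holomorphic_part + antiholomorphic_part.conjugate()

print(tangent_out)                             # (22+0j)
print(holomorphic_part, antiholomorphic_part)  # (11+2j) (-5+10j)
print(f_vjp_u)                                 # (6-8j)
```

The last value matches the (6-8j) printed by jax.vjp above, which supports the reading of current behavior given in the comment.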

mattjj commented 4 years ago

Thanks everyone for the great discussion!

I don't think we settled on anything to change here. I'm going to close the issue since it doesn't seem active, but if anyone thinks we should migrate this to a Discussion, let me know!

quantshah commented 2 years ago

@geraschenko and @mattjj Thanks for all the explanations. I am still trying to wrap my head around it all, but I have to say that the situation where one may want to differentiate a function R -> C may come up in quantum physics / quantum information. A simple case is differentiating a quantum operation f(a) = U(a) @ z, where a is a set of real numbers defining a unitary matrix and z is a complex-valued wave vector. Here is a use case with JAX that learns a unitary operation using gradient descent: https://qgrad.readthedocs.io/en/latest/Unitary-Learning-qgrad.html.

If we take @shoyer's original example, a variational quantum circuit does exactly that: it takes in a real input and outputs a complex quantity. Admittedly, the cost function in many methods is ultimately a real-valued function of z, so we have a function R -> C -> R and this problem does not arise. But I have other use cases where we are directly interested in the gradients df/da, and there it took me a while to notice this quirk while writing custom VJPs. Everything was failing silently, which was very strange and frustrating until I came across this discussion. Is there a way to add a check here and give a warning or error message?
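
For R -> C functions like this one, the full complex derivative can still be recovered from the VJP by pulling back both basis cotangents, 1 and 1j, following the standard-basis construction mentioned earlier in the thread. The plain-Python sketch below assumes the VJP at a real primal is u -> Re(u * f'(a)), a model consistent with the 1j * x transcripts above but not taken from JAX's source, and uses the toy map f(a) = exp(1j * a). The name model_vjp is hypothetical.

```python
import cmath


def df(a: float) -> complex:
    """Exact derivative of the toy R -> C map f(a) = exp(1j * a)."""
    return 1j * cmath.exp(1j * a)


def model_vjp(a: float):
    """Hypothetical VJP model at a real primal: cotangent times f'(a), keep the real part."""
    d = df(a)
    return lambda u: (u * d).real


a = 0.3
f_vjp = model_vjp(a)
print(f_vjp(1 + 0j))  # only Re(f'(a)); the imaginary part is lost

# Pulling back both basis cotangents recovers the full complex derivative:
# f_vjp(1) = Re(f'(a)) and f_vjp(1j) = Re(1j * f'(a)) = -Im(f'(a)).
recovered = f_vjp(1 + 0j) - 1j * f_vjp(0 + 1j)
print(abs(recovered - df(a)) < 1e-12)  # True
```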

Btw, JAX is used as one of the backends for AD in the PennyLane package, which allows differentiation of quantum circuits (https://pennylane.readthedocs.io/en/stable/introduction/interfaces/jax.html).