jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

change grad complex conjugation convention? #4891

Open mattjj opened 3 years ago

mattjj commented 3 years ago

A good way to define grad in terms of jvp is as the vector grad(f)(x) that satisfies, for all v,

<grad(f)(x), v> = jvp(f, (x,), (v,))[1]

But when the domain of f is a complex vector space,

  1. do we mean the standard inner product on complex vector spaces (a positive-definite sesquilinear form: conjugate each coordinate of the first vector, multiply by the corresponding coordinate of the second, and sum), or
  2. do we mean the symmetric bilinear form (multiply corresponding components and sum, with no conjugation)?

The two options lead to different gradient vectors which differ by an elementwise conjugation.

This is just a convention, and users can always just elementwise-conjugate the result if they don't like our choice. But it seems worth thinking through the pros and cons of each convention.

We currently define grad(f)(x) == vjp(f, x)[1](1.0). This corresponds to the second convention: we have jnp.dot(grad(f)(x), v) == jvp(f, (x,), (v,))[1] for all v, and we do not have jnp.vdot(grad(f)(x), v) == jvp(f, (x,), (v,))[1] for all v (notice that jnp.vdot conjugates its first argument whereas jnp.dot does not).
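As a quick numerical sanity check of both claims, here is a minimal sketch using the holomorphic function jnp.sin at an arbitrarily chosen point and tangent:

import jax
import jax.numpy as jnp

f = jnp.sin
x = 1.0 + 1.0j
v = 0.3 - 0.7j

g = jax.grad(f, holomorphic=True)(x)      # current convention: equals cos(x)
_, tangent_out = jax.jvp(f, (x,), (v,))   # tangent output: cos(x) * v

print(jnp.dot(g, v), tangent_out)    # these agree (jnp.dot does not conjugate)
print(jnp.vdot(g, v), tangent_out)   # these differ (jnp.vdot conjugates g)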

The first convention seems nice because

On the other hand, the second convention felt more natural to us at one point because

Unless we have a really strong reason to switch conventions, it might be too annoying to existing users for us to change anything...

mattjj commented 3 years ago

On the other hand, the current convention with grad is nicely consistent with jax.jacfwd and jax.jacrev, which give the matrix representing the linearized version of f:

In [18]: jax.jacrev(jnp.sin, holomorphic=True)(1. + 1j)
Out[18]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)

In [19]: jax.jacfwd(jnp.sin, holomorphic=True)(1. + 1j)
Out[19]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)

In [20]: jax.grad(jnp.sin, holomorphic=True)(1. + 1j)
Out[20]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)

That is, we currently have the property that grad produces the Jacobian matrix (for functions with scalar output), whereas under the alternative convention described above, grad would produce a representer vector for the linear map.

I still lean towards switching the convention, at the very least because then I don't have to change some slides I've made stating that grad produces a representer vector :P

EDIT: a crisper version of this point, which @dougalm pointed out, is that if you have a function like f = lambda x: x * (1 + 2j), you might expect grad(f)(1.) to be 1 + 2j. Indeed, that is the Jacobian, but if we wanted to follow the "grad gives a representer vector to be used in an inner product" definition above, we'd get 1 - 2j.
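A minimal runnable sketch of that example (evaluated at the complex point 1.0 + 0j, since grad with holomorphic=True expects complex inputs):

import jax

f = lambda z: z * (1 + 2j)
print(jax.grad(f, holomorphic=True)(1.0 + 0j))
# Current convention: (1+2j), i.e. the 1x1 Jacobian.
# Under the representer-vector convention the result would be the conjugate, (1-2j).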

shoyer commented 3 years ago

This change would work out elegantly for custom_linear_solve. For example, right now, jax.scipy.sparse.linalg.cg only sets symmetric=True for real-valued inputs, which means that for complex inputs the backwards pass needs to transpose the input function. On the other hand, every linear function used with cg must satisfy self_adjoint=True, including complex-valued functions. Linear solves with self-adjoint linear operators are much more common than solves with symmetric-complex linear operators.

Another minor factor in favor of this new convention is that it matches the standard definition of the "adjoint" in computational physics.

lukepfister commented 3 years ago

Just curious-- are you still considering changing the convention to Option 1?

shoyer commented 3 years ago

I still think switching the convention would be a wise call, both for convenience and consistency with the mathematical literature and other auto-diff software.

The main downside is that it would break existing gradients of functions with complex-valued inputs, so users with complex-valued inputs would need to be aware of the issue. One possible approach would be a deprecation cycle: add explicit options for setting the convention to user-facing functions like jax.grad(), and start issuing warnings when the option isn't set.

mattjj commented 3 years ago

both for convenience and consistency with the mathematical literature and other auto-diff software.

Well, right now it's consistent with Autograd :)

I'm not sure what's consistent with the mathematical literature. It seems plausible that both conventions are out there. Then we'd just have to choose our favorite, and as with differentiation notation, that's not always the most popular!

I also prefer Option 1, but a convention choice like this seems like such a minor convenience issue, and careful deprecation seems like such a pain, that I'm not sure if it'll ever rise to the top of someone's priority list.

mattjj commented 3 years ago

@shoyer I'm not sure I followed your earlier point about custom_linear_solve. This convention only affects the grad wrapper, not vjp. Can you say more about how it would affect custom_linear_solve?

shoyer commented 3 years ago

@shoyer I'm not sure I followed your earlier point about custom_linear_solve. This convention only affects the grad wrapper, not vjp. Can you say more about how it would affect custom_linear_solve?

Ah, interesting. So what I had in mind here is actually a more pervasive change to JAX internals, one that is perhaps decoupled from the user-facing grad.

My proposal is essentially to convert JAX's transpose rules into conjugate-transpose rules for complex-valued functions. This would imply that the VJP becomes the usual vector-matrix product for complex numbers, i.e., jnp.vdot(u, jvp(f, (x,), (v,))[1]) == jnp.vdot(vjp(f, x)[1](u), v), or u.H @ J @ v = (u.H @ J) @ v = (J.H @ u).H @ v = u.H @ (J @ v), where .H denotes the conjugate transpose. It thus also directly implies a switched convention for jax.grad, still defined by grad(f)(x) == vjp(f, x)[1](1.0).
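To illustrate just the linear-algebra identity behind this, independent of any autodiff convention, a small sketch with an arbitrary complex matrix J and vectors u and v:

import jax.numpy as jnp

J = jnp.array([[1.0 + 2.0j, 0.5 - 1.0j],
               [3.0 - 0.5j, 2.0 + 1.0j]])
u = jnp.array([1.0 - 1.0j, 0.5 + 2.0j])
v = jnp.array([2.0 + 0.5j, -1.0 + 1.0j])

lhs = jnp.vdot(u, J @ v)            # u.H @ (J @ v)
rhs = jnp.vdot(J.conj().T @ u, v)   # (J.H @ u).H @ v
print(lhs, rhs)                     # equal up to floating-point error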

This "conjugate transpose" convention is convenient for auto-diff rules like custom_linear_solve because a key question for determining appropriate methods for linear solves is whether the matrix is Hermitian or not, i.e., equal to its conjugate transpose. If so, we can use Cholesky factorization or conjugate gradients; if not, we have to use less efficient methods such as LU factorization. However, the convention that custom_linear_solve needs to know (the symmetric argument) is whether the transpose of the linear operator is equal to the linear operator, which is a different convention.

As I understand it, both TensorFlow and PyTorch use this alternative "conjugate transpose" definition for defining their gradient rules, too, not just user-facing gradients. For example, JAX's transpose rules differ by a complex conjugate from the gradient rules in TensorFlow.

Mathematically, this convention is convenient because it's the same convention used in the literature on the "adjoint state method" for gradient-based optimization with physical constraints. The adjoint method is arguably the original motivation for the development of auto-diff software, before the recent rise of deep learning.

ianwilliamson commented 3 years ago

As I understand it, both TensorFlow and PyTorch use this alternative "conjugate transpose" definition for defining their gradient rules, too, not just user-facing gradients. For example, JAX's transpose rules differ by a complex conjugate from the gradient rules in TensorFlow.

I think this is correct. Having written several VJP "adapters" for functions in TF which have complex inputs / outputs, I've found that an additional conjugation is usually necessary to get the correct gradients in JAX.

Mathematically, this convention is convenient because it's the same convention used in the literature on the "adjoint state method" for gradient-based optimization with physical constraints. The adjoint method is arguably the original motivation for the development of auto-diff software, before the recent rise of deep learning.

+1

shoyer commented 3 years ago

One question might be: who uses complex-valued gradients and VJPs in JAX?

My admittedly myopic impression is that the main users these serve are physical scientists studying fields like quantum mechanics and electromagnetism, both of which are formulated in terms of complex values. I'm quite confident that these users, at least, would be happier with the "conjugate transpose" convention.

In an attempt to answer this quantitatively, I did searches both inside Google and on GitHub for "jax holomorphic=True":

The first project is doing "digital filter design." All the others are doing quantum physics, for which I can attest (as someone who did my PhD in the field) that the conjugate transpose convention is a universal choice.

shoyer commented 3 years ago

On the bright side, we do already have the required holomorphic=True keyword arguments, which makes it very easy to identify users relying on this. The limited number of users suggests that this might not be so disruptive of a change.

shoyer commented 3 years ago

To understand how disruptive changing the internal transpose/VJP convention would be, I searched for custom_vjp in these repositories and turned up exactly one case, in netket by @PhilipVinc of mpi4jax fame.

lukepfister commented 3 years ago

I + others at LANL are working on a jax-backed library for computational imaging / inverse problems (first open source release expected in a month or two). Applications where complex gradients appear include inverse scattering, phase retrieval, and coherent x-ray diffraction imaging.

As a user, the "conjugate transpose" convention would be much simpler.

PhilipVinc commented 3 years ago

I hope that is not a bad fame...

Of course doing quantum physics, as you correctly pointed out, I fully support you switching conventions. However I don't fully understand how that is related to custom_vjp and not to vjp itself. If you are worried about that single usage of custom_vjp, don't. I'll gladly update our code.

In fact, in NetKet we already wrap your vjp and grad in order to support conjugation of the output (and in order to do that automatically and support arbitrary mixes of real/complex inputs for C->C, R->R and R->C functions... but I wander off topic).

My proposal is to essentially to convert JAX's transpose rules into conjugate_transpose rules for complex valued functions.

There is precedent for that, as this is the convention chosen in Julia's AD rules metapackage, ChainRules.jl. Though I believe that, mathematically, this is called an adjoint and not a transpose?

mattjj commented 3 years ago

My point about distinguishing vjp and grad is that a good way to define gradients is as representer vectors for linear maps via an inner product, as in the OP. But if we say VJPs work on covectors then we need not involve an inner product choice (or hence conjugation) at all, though to work with arrays of coefficients we do need to choose a covector basis. So the choices here are formally decoupled in the mathematics, and correspondingly we can decouple the decisions in code.

As a user, the "conjugate transpose" convention would be much simpler.

Perhaps it'd be simpler, but is it really much simpler? It's trivial for users to write their own grad wrapper to implement whatever convention they want, or to add a conjugation to their custom VJP rules. For that reason it doesn't seem worth the effort to change the JAX convention. It does seem worth documenting, though!
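For instance, a minimal sketch of such a wrapper, implementing the "elementwise-conjugate the result" option mentioned in the original post (the name conj_grad is hypothetical):

import jax
import jax.numpy as jnp

def conj_grad(f, **grad_kwargs):
    # Like jax.grad, but returns the elementwise conjugate of the result,
    # i.e. the other convention discussed in this thread.
    g = jax.grad(f, **grad_kwargs)
    return lambda *args: jax.tree_util.tree_map(jnp.conj, g(*args))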

@shoyer I haven't yet read and grokked your example of custom_linear_solve; sorry! About holomorphicity though, I think this convention choice is orthogonal to holomorphicity, in the sense that it's about a choice of inner product on the domain and hence applies to C^n -> R functions as well.

To summarize, at the risk of being repetitive: I agree that the conjugated convention is slightly aesthetically nicer and, well, more conventional, but it doesn't seem worth spending time on ~such~ a potentially ~trivial and~ inconsequential convention choice other than documenting it. EDIT: At least, without having yet grokked all the points in this thread, one plausible hypothesis seems that it's a relatively inconsequential convention choice, and I'd like to figure out if that's actually true!

(Edited the last couple sentences of the above paragraph so as not to use dismissive language. Sorry!)

shoyer commented 3 years ago

@shoyer I haven't yet read and grokked your example of custom_linear_solve; sorry! About holomorphicity though, I think this convention choice is orthogonal to holomorphicity, in the sense that it's about a choice of inner product on the domain and hence applies to C^n -> R functions as well.

Ah, good point. Those are indeed the most important case, and they aren't covered by holomorphic=True.

mattjj commented 3 years ago

Sorry, in my post just above I'd originally written that this was "a trivial and inconsequential convention", but that language was way too dismissive, and suggested a level of confidence I do not hold, especially since I haven't yet read and grokked all the points in this thread!

Instead I just wanted to make sure we consider the hypothesis that this choice of convention is pretty inconsequential in the sense that it just amounts to whether folks have to add a conj sometimes when they call grad or at the end of their custom_vjp rule. Even though I haven't grokked the example yet, it sounds like custom_linear_solve is a good test case to investigate whether the difference in convention amounts to "just add a conj" or something trickier.

shoyer commented 2 years ago

On the bright side, we do already have the required holomorphic=True keyword arguments, which makes it very easy to identify users relying on this. The limited number of users suggests that this might not be so disruptive of a change.

I realize now that looking for holomorphic=True is not sufficient for identifying users who would be broken by this change. You only need holomorphic=True for differentiating complex -> complex functions, not complex -> real functions.

lezcano commented 1 year ago

A note on why, from a mathematical perspective, PyTorch's gradient is correct and JAX's is the conjugate of the correct gradient:

Most of these concepts (gradients, inner products, vectors, covectors, etc.) are geometric concepts first. For example, the gradient is defined only for functions into the real numbers, and it's defined geometrically as "the direction of steepest ascent". Now, when you have a function C^n -> R, you can see it as a function R^{2n} -> R and compute its gradient as you normally would, using the canonical inner product of R^{2n}. Computationally speaking, you can save yourself the trip from C^n to R^{2n} by noting that the canonical inner product of R^{2n} can be represented as the real inner product on C^n given by Re(x^H y). A short computation shows that this is the same as representing x and y as vectors of length 2n and computing their regular inner product.
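A quick numerical check of that identification, with arbitrarily chosen vectors:

import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, -0.5 + 1.0j])
y = jnp.array([0.5 - 1.0j, 2.0 + 3.0j])

real_inner = jnp.real(jnp.vdot(x, y))     # Re(x^H y) on C^n

x_r = jnp.concatenate([x.real, x.imag])   # the same vectors viewed in R^{2n}
y_r = jnp.concatenate([y.real, y.imag])
print(real_inner, jnp.dot(x_r, y_r))      # equal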

Note that these are algebraic formulas to compute the gradient numerically, but geometrically the gradient is well defined as "the direction of steepest ascent". If whatever formulas you use do not give you this vector, then those formulas are not computing the gradient, but computing something else.
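To make this concrete with the standard example f(z) = |z|^2 at z = 3 + 4i: directional derivatives, which jvp computes and which are convention-free, single out 6 + 8i as the direction of steepest ascent, while jax.grad under the current convention returns the conjugate:

import jax
import jax.numpy as jnp

f = lambda z: jnp.real(z) ** 2 + jnp.imag(z) ** 2   # C -> R, i.e. |z|^2
z = 3.0 + 4.0j

for direction in (6.0 + 8.0j, 6.0 - 8.0j):
    v = direction / abs(direction)          # unit-norm direction
    _, dd = jax.jvp(f, (z,), (v,))          # directional derivative along v
    print(direction, dd)                    # 10.0 along 6+8j, -2.8 along 6-8j

print(jax.grad(f)(z))                       # 6-8j under the current convention:
                                            # the conjugate of the steepest-ascent
                                            # direction 6+8j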

Now, the same happens for functions from C^n -> C^m, but in this case you have the adjoint with respect to the canonical real inner products of C^n and C^m.

The confusion in the OP and the rest of this issue is that it considers a complex inner product on C^n (a sesquilinear form). This has nothing to do with gradients, as the gradient is a concept that appears at the level of real Hilbert spaces (or Riemannian manifolds, if you want to go geometric) rather than complex Hilbert spaces (or Kähler manifolds). A slightly technical example: we can compute the derivatives of the complex SVD and QR decompositions. These use unitary matrices, and the set of unitary matrices is not a Kähler manifold (there's that pesky conjugation in its definition), but only a Riemannian manifold. This hints at the fact that complex inner products, holomorphic functions, Kähler manifolds and so on have (almost) nothing to do with this theory.

Some time ago, seeing that there was a bit of confusion about these concepts in the PyTorch community as well, I wrote the following document, https://arxiv.org/abs/2207.06114, where I start from basic concepts and develop the usual coordinate-free theory of differentiation until we have enough tools to describe what's going on in the case of complex vector spaces.

lezcano commented 1 year ago

Addendum: The link from the geometric definition of the gradient to the one that's usually given, "the vector g_x such that (df)_x(v) = <g_x, v>", is given by the Cauchy-Schwarz inequality.

Let V be a finite-dimensional real vector space with a real inner product <-,->. For a given x \in V you want to compute the direction of maximum ascent of f : V -> R, that is, the direction in which the directional derivative is largest. Mathematically, you want to solve max_{v \in V, \norm{v}=1} (df)_x(v). Since the inner product is positive definite, it is in particular non-degenerate, so you can write (df)_x(v) = <g, v> for some vector g \in V (you can do this with any non-degenerate bilinear form B : V x V -> R). Then, by Cauchy-Schwarz, <g, v> <= \norm{g}\norm{v}, with equality iff v is a positive multiple of g. As such, the direction of maximum ascent is given by g. Note that if we really want \norm{v} = 1, then the direction would be v = g / \norm{g}, but it's the same thing up to scale. We usually define the gradient to be g, as it also conveniently coincides with "the vector that represents the linear form (df)_x with respect to a given inner product". The fact that it's just a direction is why we need to use and tune learning rates.
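In symbols, restating the two characterizations above:

(df)_x(v) = <grad f(x), v>  for all v \in V,        argmax_{\norm{v}=1} (df)_x(v) = grad f(x) / \norm{grad f(x)}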

Note that you can do the same in C^n, as it can be seen as a real vector space of dimension 2n, and we can equip it with the canonical real inner product Re(x^H y).