mattjj opened this issue 4 years ago
On the other hand, the current convention with grad is nicely consistent with jax.jacfwd and jax.jacrev, which give the matrix representing the linearized version of f:
In [18]: jax.jacrev(jnp.sin, holomorphic=True)(1. + 1j)
Out[18]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)
In [19]: jax.jacfwd(jnp.sin, holomorphic=True)(1. + 1j)
Out[19]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)
In [20]: jax.grad(jnp.sin, holomorphic=True)(1. + 1j)
Out[20]: DeviceArray(0.8337299-0.9888977j, dtype=complex64)
That is, we currently have the property that grad produces a Jacobian matrix (requiring that the output is a scalar), whereas the alternative convention described above is that grad produces a representer vector for the linear map.
I still lean towards switching the convention, at the very least because then I don't have to change some slides I've made stating that grad produces a representer vector :P
EDIT: a crisper version of this point, which @dougalm pointed out, is that if you have a function like f = lambda x: x * (1 + 2j), you might expect grad(f)(1.) to be 1 + 2j. Indeed that is the Jacobian, but if we want to follow the "grad gives a representer vector to be used in an inner product" definition above, we'd get 1 - 2j.
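For concreteness, here is a minimal sketch of that example (the input 1. + 0j is only needed because holomorphic=True requires a complex input):

```python
import jax
import jax.numpy as jnp

f = lambda x: x * (1 + 2j)

# Current convention: grad is the (1x1) Jacobian, so this gives 1 + 2j.
# holomorphic=True is needed because f is complex -> complex.
jacobian_style = jax.grad(f, holomorphic=True)(1. + 0j)

# The "representer vector" convention differs by an elementwise conjugate,
# which for this f would give 1 - 2j instead.
representer_style = jnp.conj(jacobian_style)
```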
This change would work out elegantly for custom_linear_solve. For example, right now jax.scipy.sparse.linalg.cg only sets symmetric=True for real-valued inputs, which means that for complex inputs the backwards pass needs to transpose the input function. On the other hand, every linear function used with cg must satisfy self_adjoint=True, including complex-valued functions. Linear solves with self-adjoint linear operators are much more common than solves with symmetric-complex linear operators.
Another minor factor in favor of this new convention is that it matches the standard definition of the "adjoint" in computational physics.
Just curious-- are you still considering changing the convention to Option 1?
I still think switching the convention would be a wise call, both for convenience and consistency with the mathematical literature and other auto-diff software.
The main downside is that it would break existing gradients of functions with complex-valued inputs, so users of complex-valued inputs would need to be aware of the issue. One possible approach to handling this would be a deprecation cycle, wherein explicit options for setting the convention are added to user-facing functions like jax.grad(), and warnings are issued when the option isn't set.
both for convenience and consistency with the mathematical literature and other auto-diff software.
Well, right now it's consistent with Autograd :)
I'm not sure what's consistent with the mathematical literature. It seems plausible that both conventions are out there. Then we'd just have to choose our favorite, and as with differentiation notation, that's not always the most popular!
I also prefer Option 1, but a convention choice like this seems like such a minor convenience issue, and careful deprecation seems like such a pain, that I'm not sure if it'll ever rise to the top of someone's priority list.
@shoyer I'm not sure I followed your earlier point about custom_linear_solve. This convention only affects the grad wrapper, not vjp. Can you say more about how it would affect custom_linear_solve?
Ah, interesting. So what I had in mind here is actually a more pervasive change to JAX internals, one that is perhaps decoupled from the user-facing grad.
My proposal is essentially to convert JAX's transpose rules into conjugate_transpose rules for complex-valued functions. This would imply that VJP becomes the usual vector-matrix product for complex numbers, i.e., jnp.vdot(u, jvp(f, (x,), (v,))[1]) == jnp.vdot(vjp(f, x)[1](u)[0], v), or u.H @ J @ v = (u.H @ J) @ v = (J.H @ u).H @ v = u.H @ (J @ v), where .H denotes the conjugate transpose. It thus also directly implies a switched convention for jax.grad, still defined by grad(f)(x) == vjp(f, x)[1](1.0).
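As a quick numerical sketch of that identity (the particular vectors here are arbitrary): under the current plain-transpose rules the non-conjugating jnp.dot pairing is the one expected to hold, and the jnp.vdot pairing is only recovered by wrapping the VJP in conjugations:

```python
import jax
import jax.numpy as jnp

f = jnp.sin  # elementwise holomorphic, so its Jacobian is diag(cos(x))
x = jnp.array([0.3 + 0.4j, -1.0 + 2.0j])
u = jnp.array([1.0 - 1.0j, 0.5 + 2.0j])   # cotangent
v = jnp.array([2.0 + 0.5j, -0.3 + 1.0j])  # tangent

jvp_out = jax.jvp(f, (x,), (v,))[1]
vjp_out = jax.vjp(f, x)[1](u)[0]

# Current convention: bilinear pairing, no conjugation anywhere.
print(jnp.allclose(jnp.dot(u, jvp_out), jnp.dot(vjp_out, v)))    # expected True

# Hermitian pairing: not expected to hold with the current plain-transpose VJP...
print(jnp.allclose(jnp.vdot(u, jvp_out), jnp.vdot(vjp_out, v)))  # expected False for generic complex inputs

# ...but it would hold for the conjugate-transpose product J.H @ u, which can
# be recovered from the current VJP as conj(vjp(conj(u))).
vjp_conj = jnp.conj(jax.vjp(f, x)[1](jnp.conj(u))[0])
print(jnp.allclose(jnp.vdot(u, jvp_out), jnp.vdot(vjp_conj, v)))  # expected True
```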
This "conjugate transpose" convention is convenient for auto-diff rules like custom_linear_solve
because a key question for determining appropriate methods for linear solves is whether the matrix is Hermitian or not, i.e., equal to its conjugate transpose. If so, we can use Cholesky factorization or conjugate gradients; if not, we have to use less efficient methods such as LU factorization. However, the convention that custom_linear_solve
needs to know (the symmetric
argument) is whether the transpose of the linear operator is equal to the linear operator, which is a different convention.
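A tiny illustration of that distinction, with a hypothetical 2x2 operator: a Hermitian complex matrix equals its conjugate transpose but generally not its plain transpose, and the plain transpose is the property the symmetric argument asks about:

```python
import jax.numpy as jnp

# Hermitian: real diagonal, conjugate-symmetric off-diagonal entries.
A = jnp.array([[2.0 + 0.0j, 1.0 - 1.0j],
               [1.0 + 1.0j, 3.0 + 0.0j]])

print(jnp.allclose(A, A.conj().T))  # True: Hermitian, so CG / Cholesky apply
print(jnp.allclose(A, A.T))         # False: not symmetric in the plain-transpose sense,
                                    # so symmetric=True would be incorrect here
```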
As I understand it, both TensorFlow and PyTorch use this alternative "conjugate transpose" definition for defining their gradient rules, too, not just user-facing gradients. For example, JAX's transpose rules differ by a complex conjugate from the gradient rules in TensorFlow.
Mathematically, this convention is convenient because it's the same convention used in the literature on the "adjoint state method" for gradient-based optimization with physical constraints. The adjoint method is arguably the original motivation for the development of auto-diff software, before the recent rise of deep learning.
As I understand it, both TensorFlow and PyTorch use this alternative "conjugate transpose" definition for defining their gradient rules, too, not just user-facing gradients. For example, JAX's transpose rules differ by a complex conjugate from the gradient rules in TensorFlow.
I think this is correct. Having written several VJP "adapters" for functions in TF which have complex inputs / outputs, I've found that an additional conjugation is usually necessary to get the correct gradients in JAX.
Mathematically, this convention is convenient because it's the same convention used in the literature on the "adjoint state method" for gradient-based optimization with physical constraints. The adjoint method is arguably the original motivation for the development of auto-diff software, before the recent rise of deep learning.
+1
One question might be: who uses complex-valued gradients and VJPs in JAX?
My admittedly myopic impression is that the main users these serve are physical scientists studying fields like quantum mechanics and electromagnetism, both of which are formulated in terms of complex values. I'm quite confident that these users, at least, would be happier with the "conjugate transpose" convention.
In an attempt to answer this quantitatively, I did searches both inside Google and on GitHub for "jax holomorphic=True":
The first project is doing "digital filter design." All the others are doing quantum physics, for which I can attest (as someone who did my PhD in the field) that the conjugate transpose convention is a universal choice.
On the bright side, we do already have the required holomorphic=True keyword arguments, which makes it very easy to identify users relying on this. The limited number of users suggests that this might not be such a disruptive change.
To understand how disruptive changing the internal transpose/VJP convention would be, I searched for custom_vjp in these repositories and turned up exactly one case, in NetKet by @PhilipVinc of mpi4jax fame.
I + others at LANL are working on a jax-backed library for computational imaging / inverse problems (first open source release expected in a month or two). Applications where complex gradients appear include inverse scattering, phase retrieval, and coherent x-ray diffraction imaging.
As a user, the "conjugate transpose" convention would be much simpler.
I hope that is not a bad fame...
Of course, doing quantum physics, as you correctly pointed out, I fully support you switching conventions. However, I don't fully understand how that is related to custom_vjp and not to vjp itself.
If you are worried about that single usage of custom_vjp, don't. I'll gladly update our code.
In fact, in NetKet we already wrap your vjp and grad in order to support conjugation of the output (and in order to do that automatically and support arbitrary mixes of real/complex inputs for C->C, R->R and R->C functions... but I wander off topic).
My proposal is essentially to convert JAX's transpose rules into conjugate_transpose rules for complex-valued functions.
There is precedent for that, as this is the convention chosen in Julia's AD rules metapackage ChainRules.jl. Though I believe that mathematically it is called an adjoint rather than a transpose?
My point about distinguishing vjp and grad is that a good way to define gradients is as representer vectors for linear maps via an inner product, as in the OP. But if we say VJPs work on covectors then we need not involve an inner product choice (or hence conjugation) at all, though to work with arrays of coefficients we do need to choose a covector basis. So the choices here are formally decoupled in the mathematics, and correspondingly we can decouple the decisions in code.
As a user, the "conjugate transpose" convention would be much simpler.
Perhaps it'd be simpler, but is it really much simpler? It's trivial for users to write their own grad wrapper to implement whatever convention they want, or to add a conjugation to their custom VJP rules. For that reason it doesn't seem worth the effort to change the JAX convention. It does seem worth documenting, though!
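For example, such a wrapper might look like the following sketch (conj_grad is a hypothetical helper, not a JAX API):

```python
import jax
import jax.numpy as jnp

def conj_grad(f, **grad_kwargs):
    # Hypothetical helper: jax.grad under the conjugating ("representer
    # vector") convention, i.e. the elementwise conjugate of jax.grad's output.
    g = jax.grad(f, **grad_kwargs)
    def wrapped(*args, **kwargs):
        return jax.tree_util.tree_map(jnp.conj, g(*args, **kwargs))
    return wrapped

f = lambda x: x * (1 + 2j)
jax.grad(f, holomorphic=True)(1. + 0j)   # 1 + 2j under the current convention
conj_grad(f, holomorphic=True)(1. + 0j)  # 1 - 2j under the conjugating convention
```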
@shoyer I haven't yet read and grokked your example of custom_linear_solve; sorry! About holomorphicity though, I think this convention choice is orthogonal to holomorphicity, in the sense that it's about a choice of inner product on the domain and hence applies to C^n -> R functions as well.
To summarize, at the risk of being repetitive: I agree that the conjugated convention is slightly aesthetically nicer and, well, more conventional, but it doesn't seem worth spending time on ~such~ a potentially ~trivial and~ inconsequential convention choice other than documenting it. EDIT: At least, without having yet grokked all the points in this thread, one plausible hypothesis is that it's a relatively inconsequential convention choice, and I'd like to figure out whether that's actually true!
(Edited the last couple sentences of the above paragraph so as not to use dismissive language. Sorry!)
@shoyer I haven't yet read and grokked your example of custom_linear_solve; sorry! About holomorphicity though, I think this convention choice is orthogonal to holomorphicity, in the sense that it's about a choice of inner product on the domain and hence applies to C^n -> R functions as well.
Ah, good point. Those are indeed the most important case, and they aren't covered by holomorphic=True.
Sorry, in my post just above I'd originally written that this was "a trivial and inconsequential convention", but that language was way too dismissive, and suggested a level of confidence I do not hold, especially since I haven't yet read and grokked all the points in this thread!
Instead I just wanted to make sure we consider the hypothesis that this choice of convention is pretty inconsequential, in the sense that it just amounts to whether folks have to add a conj sometimes when they call grad or at the end of their custom_vjp rule. Even though I haven't grokked the example yet, it sounds like custom_linear_solve is a good test case to investigate whether the difference in convention amounts to "just add a conj" or something trickier.
On the bright side, we do already have the required holomorphic=True keyword arguments, which makes it very easy to identify users relying on this. The limited number of users suggests that this might not be such a disruptive change.
I realize now that looking for holomorphic=True is not sufficient for identifying users who would be broken by this change. You only need holomorphic=True for differentiating complex -> complex functions, not complex -> real functions.
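For instance, a sketch of a complex -> real use that such a search would miss, since it needs no holomorphic=True at all:

```python
import jax
import jax.numpy as jnp

# complex -> real, so jax.grad works without holomorphic=True; a search for
# the keyword would not find this user even though the gradient is complex.
f = lambda z: jnp.sum(jnp.abs(z) ** 2)
z = jnp.array([1.0 + 2.0j, -0.5 + 0.5j])
g = jax.grad(f)(z)  # complex-valued gradient; its conjugation is exactly the convention at issue
```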
A note on why, from a mathematical perspective, PyTorch's gradient is correct and JAX's is the conjugate of the correct gradient:
Most of these concepts (gradients, inner products, vectors, covectors, etc.) come from very geometrical ideas. For example, the gradient is defined just for functions into the real numbers, and it's defined geometrically as "the direction of steepest ascent". Now, when you have a function from C^n -> R, you can see it as a function R^{2n} -> R, and you can compute its gradient as you'd normally do using the canonical inner product of R^{2n}. Computationally speaking, you can save yourself the trip from C^n to R^{2n} by noting that the canonical inner product of R^{2n} can be represented as the real inner product on C^n given by Re(x^H y). A short computation shows that this is the same as representing x and y as vectors of length 2n and computing their regular inner product.
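That short computation can be checked numerically with a couple of arbitrary vectors (a sketch):

```python
import jax.numpy as jnp

x = jnp.array([1.0 + 2.0j, -0.5 + 0.3j])
y = jnp.array([0.7 - 1.0j, 2.0 + 0.1j])

# Real inner product on C^n: Re(x^H y).
lhs = jnp.real(jnp.vdot(x, y))

# The same vectors viewed as elements of R^{2n}.
x_r = jnp.concatenate([jnp.real(x), jnp.imag(x)])
y_r = jnp.concatenate([jnp.real(y), jnp.imag(y)])
rhs = jnp.dot(x_r, y_r)

print(jnp.allclose(lhs, rhs))  # expected True
```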
Note that these are algebraic formulas to compute the gradient numerically, but geometrically the gradient is well defined as "the direction of steepest ascent". If whatever formulas you use do not give you this vector, then those formulas are not computing the gradient, but computing something else.
Now, the same happens for functions from C^n -> C^m, but in this case you have the adjoint with respect to the canonical real inner products of C^n and C^m.
The confusion in the OP and the rest of this issue is that it's considering a complex inner product on C^n (a sesquilinear form). This has nothing to do with gradients, as the concept of a gradient appears at the level of real Hilbert spaces (or Riemannian manifolds if you want to go geometrical) rather than complex Hilbert spaces (or Kähler manifolds). A slightly technical example here would be that we can compute the derivatives of the complex SVD and QR decompositions. These use unitary matrices, and the set of unitary matrices is not a Kähler manifold (there's that pesky conjugation in its definition), but only a Riemannian manifold. This hints at the fact that complex inner products, holomorphic functions, Kähler manifolds and so on have (almost) nothing to do with this theory.
Some time ago, as I saw that there was a bit of confusion wrt. these concepts also in the PyTorch community, I wrote the following document https://arxiv.org/abs/2207.06114 where I start from basic concepts and I develop the usual coordinate-free theory of differentiation until we have enough tools to define what's going on in the case of complex vector spaces.
Addendum: The link from the geometric definition of the gradient to the one that's usually given, "the vector g_x such that (df)_x(v) = <g_x, v>", is given by the Cauchy-Schwarz inequality.
Let V be a finite-dimensional real vector space with a real inner product <-,->. For a given x \in V you want to compute the direction of maximum ascent of f : V -> R, that is, the direction in which the directional derivative is largest. Mathematically, you want to solve max_{v \in V, \norm{v}=1} (df)_x(v). Since the inner product is positive definite, it is in particular invertible, so you can write (df)_x(v) = <g, v> for some vector g \in V (you can do this with any invertible bilinear form B : V x V -> R). Then, by Cauchy-Schwarz, you know that <g, v> <= \norm{g}\norm{v}, with equality iff v is a positive multiple of g. As such, the direction of maximum ascent is given by g. Note that if we really want \norm{v} = 1, then the maximizer is v = g / \norm{g}, but yeah, same thing. We usually define the gradient as g, as it also conveniently coincides with "the vector that represents the linear form (df)_x with respect to a given inner product". The fact that it's just a direction is why we need to use and tune learning rates.
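A small numerical sketch of that argument (the function and point are arbitrary): among a grid of unit directions in R^2, the directional derivative is maximized at (approximately) grad f / ||grad f||:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x[0]) * jnp.exp(x[1])  # arbitrary R^2 -> R function
x = jnp.array([0.3, -0.7])

g = jax.grad(f)(x)
g_dir = g / jnp.linalg.norm(g)

# Sample unit directions and compute directional derivatives (df)_x(v) via jvp.
thetas = jnp.linspace(0.0, 2 * jnp.pi, 361)
dirs = jnp.stack([jnp.cos(thetas), jnp.sin(thetas)], axis=-1)
ddirs = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(dirs)

best_dir = dirs[jnp.argmax(ddirs)]
print(jnp.allclose(best_dir, g_dir, atol=1e-2))  # expected True, up to the angular grid spacing
```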
Note that you can do the same in C^n, as it can be seen as a real vector space of dimension 2n, and we can equip it with the canonical real inner product Re(x^H y).
A good way to define grad in terms of jvp is as the vector grad(f)(x) that satisfies <grad(f)(x), v> == jvp(f, (x,), (v,))[1] for all v, for some choice of inner product <., .> on the domain. But when the domain of f is a complex vector space, there are two natural choices of inner product:
1. the Hermitian inner product, which conjugates its first argument (like jnp.vdot), or
2. the bilinear pairing, which does not (like jnp.dot).
The two options lead to different gradient vectors, which differ by an elementwise conjugation.
This is just a convention, and users can always just elementwise-conjugate the result if they don't like our choice. But it seems worth thinking through the pros and cons of each convention.
We currently define grad(f)(x) == vjp(f, x)[1](1.0). This corresponds to the second convention, i.e. where we have jnp.dot(grad(f)(x), v) == jvp(f, (x,), (v,))[1] for all v, and we do not have jnp.vdot(grad(f)(x), v) == jvp(f, (x,), (v,))[1] for all v (notice that jnp.vdot conjugates its first argument where jnp.dot does not).
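(A quick sanity-check sketch of that definition, using the sin example from earlier in the thread:)

```python
import jax
import jax.numpy as jnp

f = jnp.sin
x = 1.0 + 1.0j

g = jax.grad(f, holomorphic=True)(x)
y, f_vjp = jax.vjp(f, x)
v = f_vjp(jnp.ones_like(y))[0]  # cotangent 1.0, cast to the output's complex dtype

print(jnp.allclose(g, v))  # expected True: grad(f)(x) == vjp(f, x)[1](1.0)
```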
The first convention seems nice because libraries that use grad internally but don't conjugate the gradient (perhaps because the original authors weren't thinking about support for complex numbers) would automatically work for complex domains.
On the other hand, the second convention felt more natural to us at one point because having grad correspond to a special case of the Jacobian matrix is a reasonable expectation, e.g. in the scalar case we might expect grad(lambda x: x * (1 + 2j))(1.) == 1 + 2j rather than getting 1 - 2j. (See comment below.)
Unless we have a really strong reason to switch conventions, it might be too annoying to existing users for us to change anything...