Jakob-Unfried opened this issue 4 years ago
These sorts of multilinear functions can be written concisely in terms of `np.einsum`.
We could make `einsum` a JAX primitive. In principle this would make it easier to optimize (or customize) the backwards pass. I experimented with this a bit in this notebook, but stopped because I couldn't come up with a compelling example of why it matters:
https://colab.research.google.com/gist/shoyer/98803242f9c0d3c4ddf442f9e063a8df/jax-einsum-primitive.ipynb
It would be cleaner to use a custom derivative rule, since we only need to customize derivatives, but we would need a transpose rule in addition to a JVP rule, which currently isn't supported by `custom_jvp`. Alternatively, you could combine the transpose and JVP rules into a VJP rule, using `custom_vjp`.
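For concreteness, here is a minimal sketch of the `custom_vjp` route (my own illustration, not the notebook's code, with a plain matmul standing in for a general `einsum`): the backward pass is written by hand, so reverse mode never differentiates through the forward implementation.

```python
import jax
import jax.numpy as jnp

# Sketch: a bilinear contraction with a hand-written VJP.
# All names here (contract, contract_fwd, contract_bwd) are illustrative.
@jax.custom_vjp
def contract(x, y):
    return jnp.einsum('ij,jk->ik', x, y)

def contract_fwd(x, y):
    # Save the primal inputs as residuals for the backward pass.
    return contract(x, y), (x, y)

def contract_bwd(res, g):
    x, y = res
    # Hand-written transpose rules of a matmul: dx = g @ y.T, dy = x.T @ g.
    return jnp.einsum('ik,jk->ij', g, y), jnp.einsum('ij,ik->jk', x, g)

contract.defvjp(contract_fwd, contract_bwd)
```

With this, `jax.grad(lambda x, y: contract(x, y).sum())` uses `contract_bwd` instead of tracing through the forward `jnp.einsum`.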
cc @Thenerdstation, who is interested in these sorts of optimizations for https://github.com/google/TensorNetwork
That notebook looks interesting, I will have a read tomorrow. Thanks, maybe I can piece together something useful from that.
Maybe I worded it in a confusing way: I cannot use einsum directly. For a small enough number of tensors (which means a small enough physical system I am simulating), I use a custom version of the ncon interface that is also in TensorNetwork, and that works perfectly.
For larger systems, it becomes too expensive to do all the contractions, even in optimal order. So I have an implementation for approximating the contraction that performs well for the forward/primal pass, but it performs badly when I try to compute gradients with it. That is expected, since it breaks the paradigm: I want to approximate the derivative, not compute the derivative of the approximation.
So what I did is write a `custom_jvp` that calls the primal pass once for every tangent_in. A simplified & ugly version, since I'm on mobile:

```python
tangent_out = 0
for n in range(len(tensors)):
    tangent_out += fun(*tensors[:n], tangents_in[n], *tensors[n+1:])
```

That works, but it feels terrible to use a for loop and it is very slow too.
What I was hoping for is a better way to do this.
edit: btw I am omitting a lot of details here, because I don't expect anyone to be interested, but if that's a wrong assumption feel free to message me
edit2: ok, I found that you did just what I was asking in the notebook
Right, for your use case my example with `einsum` is just a proof of concept. You would want to swap out the implementation for your customized approximation.
The key thing is that the high-level tensor contraction needs to be a primitive from the perspective of the auto-diff system. You can either do this with a custom primitive or a `custom_vjp` (`einsum` is linear, so `custom_jvp` won't work).
Why is being linear a problem for using `custom_jvp`?
Thanks for your help so far, appreciate it.
> why is being linear a problem for using `custom_jvp`?
`custom_jvp` works fine for forward-mode auto-diff, but it doesn't suffice for reverse-mode auto-diff of a linear function, because the function itself will appear inside the `custom_jvp` rule. To evaluate reverse mode, JAX needs to be able to transpose any function that is used in a JVP, which currently requires either using `custom_vjp` or writing a new Primitive.
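As an aside (my addition, not part of the original reply): `jax.linear_transpose` exposes this transposition machinery directly, which can help build intuition for what JAX has to do under the hood for reverse mode.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2.0 * x  # a linear function JAX knows how to transpose

# linear_transpose only needs the shape/dtype of the primal input.
f_t = jax.linear_transpose(f, jnp.zeros(3))
(ct,) = f_t(jnp.array([1.0, 2.0, 3.0]))
# The transpose of scaling by 2 is again scaling by 2: ct == [2., 4., 6.]
```

If `f` contained an operation without a transpose rule, `f_t` would fail, which is exactly the failure mode described above for linear functions inside a `custom_jvp` rule.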
Ah, I see, thank you.
I guess that also contributed to the slow runtime of my `custom_jvp`, since grad was running in forward mode.
I will see what I can do with it tomorrow and post an update.
I guess we could try implementing a function decorator that asserts that a given function is multilinear (e.g. `@multilinear`) and have it work similarly to the invertible AD I started working on in #3232 (that only requires you to mark the function as `@invertible`).
I am struggling to deal with the transpose rule. Is there some documentation of what it should do? I mean both the input/output structure and what mathematical object it should compute.
@Jakob-Unfried Have you looked at https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html?
@gnecula Thank you.
FYI: At https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html the notebook is not displayed correctly for me.
It works at https://jax.readthedocs.io/en/stable/notebooks/How_JAX_primitives_work.html (stable instead of latest)
I think that this was fixed a few hours ago, I am not sure why readthedocs did not pick it up.
I figured that it shouldn't be too hard to implement, so I decided to just go ahead. Can you please check if this implementation satisfies your requirements?
```python
import jax
import jax.numpy as jnp

def _with_ith(l, i, x, fun):
    # Temporarily replace l[i] with x, call fun, then restore l[i].
    o = l[i]
    l[i] = x
    try:
        return fun(*l)
    finally:
        l[i] = o

def multilinear(f):
    fml = jax.custom_jvp(f)

    @fml.defjvp
    def multilinear_jvp(primals, tangents):
        # TODO: This assumes flat inputs!
        args = list(primals)
        return f(*primals), sum(_with_ith(args, i, t, f) for i, t in enumerate(tangents))

    return fml
```
It should work with both forward-mode and reverse-mode AD, but it will just use the original implementation of your function to compute the derivatives instead of trying to differentiate through it. Here is a small example showing that it works:
```python
call = lambda f: lambda *args: jax.core.call(jax.linear_util.wrap_init(lambda *args: [f(*args)]), *args)[0]

@multilinear
@call
def f(x, y, z):
    return x * y * z

x = jnp.ones((7,))
i = (x, x, x)
print(jax.make_jaxpr(lambda args, tans: jax.jvp(f, args, tans))(i, i))
print(jax.make_jaxpr(lambda args, ctans: jax.vjp(f, *args)[1](ctans))(i, x))
```
The `call` part is not strictly necessary, and in your use case you should just use `@multilinear`. However, it makes the jaxprs I print nicer by showing more clearly that the original function ends up being computed as many times as there are arguments, but each time one of them gets swapped for a tangent value.
For example, the forward derivative of `f` is just this (`a`, `b`, `c` are primal inputs; `d`, `e`, `f` are tangents wrt each primal input):
@apaszke your example looks really elegant. Here's a test case for matrix multiplication:

```python
@multilinear
@call
def g(x, y):
    return x @ y

x = jnp.ones((3, 4))
y = jnp.ones((4, 5))
i = (x, y)
z = jnp.ones((3, 5))
print(jax.make_jaxpr(g)(*i))  # works
print(jax.make_jaxpr(lambda args, tans: jax.jvp(g, args, tans))(i, i))  # works
print(jax.make_jaxpr(lambda args, ctans: jax.vjp(g, *args)[1](ctans))(i, z))  # works
```
I'm still not sure how you manage to avoid the need for a transpose rule!
EDIT: this actually works correctly, I had a bug in my original example
I see now that @apaszke's example works because JAX transposes the implementation of the multilinear function.
I suspect this will not suffice for @Jakob-Unfried's use case, because it doesn't resolve the fundamental issue that `transpose(approx(multilinear_fun))` is not a good substitute for `approx(transpose(multilinear_fun))`. We need some way to customize the transpose rule, too.
Chiming in for a bit of context, as I understand it. JAX computes derivatives in two steps: it first linearizes the function, then (for reverse mode) transposes the resulting linear function. For a linear function `f(dx) = dy`, the transpose is a function `f*(dy*) = dx*` such that `<f(dx), dy*> = <dx, f*(dy*)>` (where `<., .>` is an inner product). You can think of this as locally transforming forward mode (given a change in `x`, does our function change in the direction we want?) into reverse mode (given a change we want, what is the best way to change our input `x`?). For a matrix multiply, the transpose is multiplication with the transposed matrix. For a reduce-sum, the transpose is a broadcast.

The approach described above (and `custom_jvp` in general) tells JAX how you want your function to be linearized. When you compute a `jvp` of it, JAX just uses the function you gave it instead of doing its normal linearization. When you use `vjp`, it uses your linearization but still transposes it using its own set of rules. (And in fact, if you use `custom_jvp` but put nonlinear stuff in the rule, you will probably get an error about a missing transpose rule!)
There's a chance that the automatic transposition will do the right thing, but my guess is that it won't in general. That is @shoyer's point above: transposing an approximation might not give the right approximate transpose (it might not work at all if you use anything nonlinear, or it might not be a good approximation even if it works). If it doesn't work, the two existing options that I know of are `custom_vjp` (which lets you skip both the first and the second step and write `f*` directly) and implementing a custom primitive with a transpose rule (which is similar, but requires more knowledge of JAX internals; on the other hand, it means you would probably get forward mode, reverse mode, and higher-order derivatives for free).
I'm not sure there's a general way to automatically derive an approximate transpose from an approximate linear function; that seems hard to do in general, since it probably depends on exactly how your approximation works.
edit: One more detail that's relevant here: since a multilinear function is linear in each argument separately, it effectively has a different transpose with respect to each of its arguments. (You can almost see this in the `@multilinear` decorator, since it splits the single call to your function into n calls to your function; after that, each call will be transposed separately with respect to its single tangent argument.)
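The defining identity of the transpose can be checked numerically. This small sketch (my own illustration) uses the fact that for a linear `f`, the `jax.vjp` at any point is exactly the transpose `f*`:

```python
import jax
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
f = lambda x: A @ x                # linear in x

x = jnp.array([1.0, 2.0, 3.0])     # a "change" dx
dy = jnp.array([1.0, -2.0])        # a cotangent dy*

_, f_vjp = jax.vjp(f, x)           # for linear f, the VJP is f*
(dx_ct,) = f_vjp(dy)

lhs = jnp.vdot(f(x), dy)           # <f(dx), dy*>
rhs = jnp.vdot(x, dx_ct)           # <dx, f*(dy*)>
# lhs == rhs by the definition of the transpose
```

Here `dx_ct` is just `A.T @ dy`, i.e. multiplication with the transposed matrix, as described above.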
Update from me:
I could not come up with a working transpose rule.
What @shoyer did in his notebook with the transpose rule for einsum is just not possible with my implementation. What is needed here is, in essence: "contract everything, but leave one tensor/operand out; the uncontracted axes that would normally have been contracted with that one tensor are the axes of the output of the transpose rule". If I could do that, I could use it to directly write down the VJP: that contraction (times the output cotangent) is the cotangent for that one tensor.
In the linked notebook, @shoyer could use einsum in the transpose rule, but with different index strings. My approximate function does not have that flexibility; it is like an approximate einsum with fixed (and only quite specific) index strings.
I am quite certain that unless you can supply JAX with an (approximate or exact) implementation of "contract everything except for one tensor", there is no chance of having either a transpose or a VJP rule. But if you can, both are simple.
So it looks like I am left with the JVP and forward mode.
Thanks for all the insights, I learned a lot from this thread.
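To make that last point concrete, here is a hedged sketch (my own toy, real-valued only, ignoring complex conjugation) of how a "contract everything except operand i" routine would make the `custom_vjp` immediate. The multilinear function is a trivial stand-in (elementwise product summed to a scalar); a real approximate contraction would replace both routines, and all names are illustrative:

```python
import jax
import jax.numpy as jnp

# Toy multilinear "network": elementwise product of three tensors,
# summed to a scalar.
def _contract_impl(x, y, z):
    return jnp.sum(x * y * z)

def contract_all_but(tensors, i):
    # "Environment" of operand i: contract everything except tensor i.
    # Its output has the shape of the omitted operand.
    rest = [t for j, t in enumerate(tensors) if j != i]
    return rest[0] * rest[1]

contract = jax.custom_vjp(_contract_impl)

def _fwd(x, y, z):
    return _contract_impl(x, y, z), (x, y, z)

def _bwd(tensors, y_bar):
    # Cotangent of operand i = output cotangent times its environment.
    return tuple(y_bar * contract_all_but(tensors, i) for i in range(3))

contract.defvjp(_fwd, _bwd)
```

With an approximate `contract_all_but`, the same `_bwd` would yield an approximate VJP without ever differentiating through the approximation.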
TL;DR: Is there a way to tell jax that a function of many variables is linear in each of them and have it handle the gradient pass efficiently?
I have a function `fun(*tensors)` of many tensors, the result of which is a (complex) number. `fun` is the contraction of a tensor network, which makes it multilinear, i.e. it fulfills

```python
fun(*tensors[:i], A + B, *tensors[i+1:]) == fun(*tensors[:i], A, *tensors[i+1:]) + fun(*tensors[:i], B, *tensors[i+1:])
```

for all `i`.
You can imagine the workings of it like this, for appropriately chosen `axes`:
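(The original code snippet is missing here. As a hedged illustration of this kind of exact contraction, my own toy, not the author's code: a chain of `tensordot` calls, here contracting a ring of matrices, which is linear in each operand separately.)

```python
import numpy as np

def fun(*tensors):
    # Toy exact contraction: a ring of matrices, tr(T1 @ T2 @ ... @ Tn).
    # Linear in each tensor separately, hence multilinear.
    result = tensors[0]
    for t in tensors[1:]:
        result = np.tensordot(result, t, axes=1)  # contract adjacent legs
    return np.trace(result)
```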
This exact contraction is, however, too expensive in most cases. For just evaluating `fun`, there are well-known approximation methods that involve compression via SVD.
If I want to compute the gradient of `fun`, however, this is a bad idea. It is both inefficient (since the approximations make the computation graph large and its connectivity complicated) and inaccurate.
Since it is linear, the JVP and VJP rules are easy to write down analytically (say for the JVP: for each tensor in the tensor list there is a term which is again `fun`, but with this tensor replaced by its tangent, and the tangent_out is the sum of all of these).
I could not find a way to implement such a JVP rule without using an inefficient for loop over the tensors.
Is there some jax functionality to do this in a clean way? I read something about the inner workings of `custom_jvp`, where it says it can automatically transpose the linear Jacobian and create a VJP rule from the JVP. So I thought maybe there is a way to tell the gradient pass that this function is linear and have it handled automatically?