google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Question: efficient gradient computation for multi-linear function #3366

Open Jakob-Unfried opened 4 years ago

Jakob-Unfried commented 4 years ago

TL;DR: Is there a way to tell jax that a function of many variables is linear in each of them and have it handle the gradient pass efficiently?

I have a function fun(*tensors) of many tensors, the result of which is a (complex) number.

fun is the contraction of a tensor network, which makes it multilinear, i.e. linear in each argument separately: fun(*tensors[:i], A + B, *tensors[i+1:]) = fun(*tensors[:i], A, *tensors[i+1:]) + fun(*tensors[:i], B, *tensors[i+1:]) for all i.

You can imagine the workings of it like this

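```python
import numpy as np

def fun(*tensors):
    # contract the network left to right; `axes` is a global list in which
    # axes[n] is the tensordot `axes` argument for contracting tensors[n + 1]
    tmp = tensors[0]
    for n, tens in enumerate(tensors[1:]):
        tmp = np.tensordot(tmp, tens, axes[n])
    return tmp
```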

for appropriately chosen axes
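A quick numerical check of the multilinearity property stated above, using a self-contained copy of fun. The concrete network (three matrices contracted into trace(a @ b @ c)) and its `axes` list are assumptions chosen for illustration, not the tensor network from the issue.

```python
import numpy as np

def fun(*tensors):
    # same left-to-right contraction as above; axes[n] contracts tensors[n + 1]
    tmp = tensors[0]
    for n, tens in enumerate(tensors[1:]):
        tmp = np.tensordot(tmp, tens, axes[n])
    return tmp

# Hypothetical network: fun(a, b, c) = trace(a @ b @ c), a scalar.
# Step 0 contracts one axis; step 1 closes the ring over both remaining axes.
axes = [1, ([0, 1], [1, 0])]

rng = np.random.default_rng(0)
a, c = rng.normal(size=(2, 3)), rng.normal(size=(4, 2))
A, B = rng.normal(size=(3, 4)), rng.normal(size=(3, 4))

# Linearity in the middle argument (the same holds for every argument).
lhs = fun(a, A + B, c)
rhs = fun(a, A, c) + fun(a, B, c)
print(np.allclose(lhs, rhs))                          # True
print(np.isclose(fun(a, A, c), np.trace(a @ A @ c)))  # True
```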

This exact contraction is, however, too expensive in most cases. For just evaluating fun, there are well-known approximation methods that involve compression via SVD.
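For context, a minimal sketch of the generic SVD compression step such approximation schemes build on; this is not the author's algorithm, and the helper name `truncated_svd` and the bond dimension `chi` are hypothetical. A matrix (e.g. two contracted tensors reshaped into a matrix) is replaced by its best rank-`chi` approximation.

```python
import numpy as np
import jax.numpy as jnp

def truncated_svd(m, chi):
    # Best rank-chi approximation of m: keep the chi largest singular values.
    u, s, vh = jnp.linalg.svd(m, full_matrices=False)
    return u[:, :chi], s[:chi], vh[:chi, :]

rng = np.random.default_rng(0)
# A 10 x 12 matrix of exact rank 4 (e.g. two contracted tensors, reshaped).
m = jnp.asarray(rng.normal(size=(10, 4)) @ rng.normal(size=(4, 12)))

# chi = 4 loses nothing, since m has rank 4.
u, s, vh = truncated_svd(m, chi=4)
approx = (u * s) @ vh  # same as u @ jnp.diag(s) @ vh
print(u.shape, s.shape, vh.shape)  # (10, 4) (4,) (4, 12)

# chi = 2 is a genuine compression; by Eckart-Young, the Frobenius error equals
# the norm of the discarded singular values.
u2, s2, vh2 = truncated_svd(m, chi=2)
err = jnp.linalg.norm(m - (u2 * s2) @ vh2)
full_s = jnp.linalg.svd(m, compute_uv=False)
print(jnp.allclose(approx, m, atol=1e-4))
print(jnp.allclose(err, jnp.linalg.norm(full_s[2:]), rtol=1e-3))
```

Repeating such truncations throughout the contraction is what makes the computation graph large and irregular, which is why differentiating straight through it is a poor idea.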

If I want to compute the gradient of fun, however, this is a bad idea. It is both inefficient (since the approximations make the computation graph large and its connectivity complicated) and inaccurate.

Since fun is multilinear, its JVP and VJP rules are easy to write down analytically. For the JVP, say, there is one term per input tensor: fun evaluated with that tensor replaced by its tangent. tangent_out is the sum of all of these terms.
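In formula form the JVP rule reads $\dot y = \sum_i f(T_1, \dots, \dot T_i, \dots, T_n)$. A small check that this agrees with what `jax.jvp` computes for an exact (non-approximated) multilinear function; the trilinear einsum `f` below is a stand-in chosen for illustration.

```python
import jax
import jax.numpy as jnp

def f(a, b, c):
    # stand-in multilinear function: trace(a @ b @ c)
    return jnp.einsum("ij,jk,ki->", a, b, c)

keys = jax.random.split(jax.random.PRNGKey(0), 6)
a, b, c = (jax.random.normal(k, (3, 3)) for k in keys[:3])
da, db, dc = (jax.random.normal(k, (3, 3)) for k in keys[3:])

# tangent_out = sum over arguments of f with that argument replaced by its tangent
manual = f(da, b, c) + f(a, db, c) + f(a, b, dc)

_, auto = jax.jvp(f, (a, b, c), (da, db, dc))
print(jnp.allclose(manual, auto, atol=1e-4))  # True
```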

I could not find a way to implement such a JVP rule without using an inefficient for loop over the tensors.

Is there some JAX functionality to do this in a clean way? I read something about the inner workings of custom_jvp, where it says that JAX can automatically transpose the (linear) JVP rule to obtain a VJP rule. So I thought maybe there is a way to tell the gradient pass that this function is (multi)linear and have it handled automatically?

shoyer commented 4 years ago

These sorts of multilinear functions can be written concisely in terms of np.einsum.
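As an illustration, a small three-tensor network written as a single `jnp.einsum` call. The shapes and the index string are assumptions for the example; because each operand appears exactly once, the result is linear in each operand.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
# Hypothetical network: every index is shared by two tensors and summed over.
a = jnp.asarray(rng.normal(size=(2, 3, 5)))  # indices i, j, p
b = jnp.asarray(rng.normal(size=(3, 4, 5)))  # indices j, k, p
c = jnp.asarray(rng.normal(size=(4, 2)))     # indices k, i

def net(a, b, c):
    # Each operand appears exactly once, so net is linear in each of them.
    return jnp.einsum("ijp,jkp,ki->", a, b, c)

# Linearity in b (the same check works for a and c).
b2 = jnp.asarray(rng.normal(size=(3, 4, 5)))
lhs = net(a, 2.0 * b + b2, c)
rhs = 2.0 * net(a, b, c) + net(a, b2, c)
print(jnp.allclose(lhs, rhs, atol=1e-4))  # True
```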

We could make einsum a JAX primitive. In principle this would make it easier to optimize (or customize) the backward pass. I experimented with this a bit in this notebook, but stopped because I couldn't come up with a compelling example of why it matters: https://colab.research.google.com/gist/shoyer/98803242f9c0d3c4ddf442f9e063a8df/jax-einsum-primitive.ipynb

It would be cleaner to use a custom derivative rule, since we only need to customize derivatives, but we would need a transpose rule in addition to a JVP rule, which custom_jvp currently doesn't support. Alternatively, you could combine the transpose and JVP rules into a VJP rule, using custom_vjp.
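A minimal sketch of the custom_vjp route for a trilinear contraction. The exact einsum stands in for an approximate contraction, and the name `contract` and its index strings are illustrative assumptions. The backward pass computes, for each argument, the contraction of all the other tensors with that argument's indices left open, scaled by the output cotangent.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def contract(a, b, c):
    # stand-in for an (approximate) contraction: trace(a @ b @ c)
    return jnp.einsum("ij,jk,ki->", a, b, c)

def contract_fwd(a, b, c):
    # save the inputs; the backward pass needs all of them
    return contract(a, b, c), (a, b, c)

def contract_bwd(res, ct):
    a, b, c = res
    # cotangent of each argument: contract everything except that argument,
    # keep its indices open, and scale by the (scalar) output cotangent
    return (ct * jnp.einsum("jk,ki->ij", b, c),
            ct * jnp.einsum("ij,ki->jk", a, c),
            ct * jnp.einsum("ij,jk->ki", a, b))

contract.defvjp(contract_fwd, contract_bwd)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
a, b, c = (jax.random.normal(k, (3, 3)) for k in keys)
out = jax.grad(contract, argnums=(0, 1, 2))(a, b, c)
ref = jax.grad(lambda a, b, c: jnp.trace(a @ b @ c), argnums=(0, 1, 2))(a, b, c)
print(all(jnp.allclose(x, y, atol=1e-4) for x, y in zip(out, ref)))  # True
```

With an approximate contraction, each of the three einsums in `contract_bwd` would be replaced by an approximation of the corresponding "all but one" contraction.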

cc @Thenerdstation, who is interested in these sorts of optimizations for https://github.com/google/TensorNetwork

Jakob-Unfried commented 4 years ago

That notebook looks interesting; I will have a read tomorrow, thanks. Maybe I can piece together something useful from it.

Maybe I worded it in a confusing way: I cannot use einsum directly. For a small enough number of tensors (i.e. a small enough physical system to simulate), I use a custom version of the ncon interface that is also in TensorNetwork, and that works perfectly.

For larger systems, it becomes too expensive to do all the contractions, even in optimal order. So I have an implementation for approximating the contraction that performs well for the forward/primal pass, but it performs badly when I try to compute gradients with it. That is expected, since it breaks the paradigm of approximating the derivative instead of computing the derivative of the approximation.

So what I did is write a custom_jvp that calls the primal pass once for every tangent_in.

A simplified & ugly version, since I'm on mobile:

```python
tangent_out = 0
for n in range(len(tensors)):
    # replace the n-th tensor by its tangent and keep all others
    tangent_out += fun(*tensors[:n], tangents_in[n], *tensors[n+1:])
```

That works, but it feels terrible to use a for loop, and it is very slow too.

What I was hoping for is a better way to do this.

Edit: by the way, I am omitting a lot of details here, because I don't expect anyone to be interested, but if that's a wrong assumption, feel free to message me.

Edit 2: OK, I found that you did just what I was asking for in the notebook.

shoyer commented 4 years ago

Right, for your use case my example with einsum is just a proof of concept. You would want to swap out the implementation for your customized approximation.

The key thing is that the high-level tensor contraction needs to be a primitive from the perspective of the auto-diff system. You can either do this with a custom primitive or a custom_vjp (einsum is linear, so custom_jvp won't work).

Jakob-Unfried commented 4 years ago

Why is being linear a problem for using custom_jvp?

Thanks for your help so far, I appreciate it.

shoyer commented 4 years ago

> why is being linear a problem for using custom_jvp?

custom_jvp works fine for forward mode auto-diff, but it doesn't suffice for reverse mode auto-diff of a linear function, because the function itself will appear inside the custom_jvp rule. To evaluate reverse mode, JAX needs to be able to transpose any function that is used in a JVP, which currently requires either using custom_vjp or writing a new Primitive.
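A small illustration of this point using the public `jax.linearize` and `jax.linear_transpose` functions: reverse mode is the transpose of the linearized (JVP) function, so every operation applied to the tangents must be transposable. The function `g` is an arbitrary example chosen for illustration.

```python
import jax
import jax.numpy as jnp

def g(x):
    # arbitrary nonlinear example function
    return jnp.sin(x) * x

x = jnp.arange(1.0, 4.0)
ct = jnp.array([1.0, -2.0, 0.5], dtype=jnp.float32)

# Step 1: linearize, i.e. get the JVP as a linear function of the tangent.
y, g_jvp = jax.linearize(g, x)
# Step 2: transpose that linear function; it maps output cotangents to input cotangents.
g_jvp_t = jax.linear_transpose(g_jvp, x)
(ct_x,) = g_jvp_t(ct)

# This agrees with what reverse mode computes.
_, g_vjp = jax.vjp(g, x)
(ct_x_ref,) = g_vjp(ct)
print(jnp.allclose(ct_x, ct_x_ref))  # True
```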

Jakob-Unfried commented 4 years ago

Ah, I see, thank you.

I guess that also contributed to the slow runtime of my custom_jvp, since grad was running in forward mode.

I will see what I can do with it tomorrow and post an update.

apaszke commented 4 years ago

I guess we could try implementing a function decorator that asserts that a given function is multilinear (e.g. @multilinear) and have it work similarly to the invertible AD I started working on in #3232 (which only requires you to mark the function as @invertible).

Jakob-Unfried commented 4 years ago

I am struggling with the transpose rule. Is there some documentation of what it should do? I mean both the input/output structure and what mathematical object it should compute.

gnecula commented 4 years ago

@Jakob-Unfried Have you looked at https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html?

Jakob-Unfried commented 4 years ago

@gnecula Thank you.

FYI: At https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html the notebook is not displayed correctly for me.

It works at https://jax.readthedocs.io/en/stable/notebooks/How_JAX_primitives_work.html (stable instead of latest)

gnecula commented 4 years ago

I think that this was fixed a few hours ago, I am not sure why readthedocs did not pick it up.

apaszke commented 4 years ago

I figured that it shouldn't be too hard to implement, so I decided to just go ahead. Can you please check if this implementation satisfies your requirements?

```python
import jax
import jax.numpy as jnp

def _with_ith(l, i, x, fun):
  # call fun with the i-th argument temporarily replaced by x
  o = l[i]
  l[i] = x
  try:
    return fun(*l)
  finally:
    l[i] = o

def multilinear(f):
  fml = jax.custom_jvp(f)

  @fml.defjvp
  def multilinear_jvp(primals, tangents):
    # TODO: This assumes flat inputs!
    args = list(primals)
    return f(*primals), sum(_with_ith(args, i, t, f) for i, t in enumerate(tangents))

  return fml
```

It should work with both forward-mode and reverse-mode AD, but it will just use the original implementation of your function to compute the derivatives instead of trying to differentiate it. Here is a small example showing that it works:

```python
call = lambda f: lambda *args: jax.core.call(jax.linear_util.wrap_init(lambda *args: [f(*args)]), *args)[0]

@multilinear
@call
def f(x, y, z):
  return x * y * z

x = jnp.ones((7,))
i = (x, x, x)
print(jax.make_jaxpr(lambda args, tans: jax.jvp(f, args, tans))(i, i))
print(jax.make_jaxpr(lambda args, ctans: jax.vjp(f, *args)[1](ctans))(i, x))
```

The call part is not strictly necessary and in your use case you should just use @multilinear. However, it makes the jaxprs I print nicer by showing more clearly that the original function ends up being computed as many times as there are arguments, but each time one of them gets swapped for a tangent value. For example, the forward derivative of f is just this (a, b, c are primal inputs, d, e, f are tangents wrt each primal input):

```
{ lambda ; a b c d e f.
  let g = call[ call_jaxpr={ lambda ; a b c.
                             let d = mul a b
                                 e = mul d c
                             in (e,) }
                name= ] a b c
      h = call[ call_jaxpr={ lambda ; a b c.
                             let d = mul a b
                                 e = mul d c
                             in (e,) }
                name= ] d b c
      i = add h 0.0
      j = call[ call_jaxpr={ lambda ; a b c.
                             let d = mul a b
                                 e = mul d c
                             in (e,) }
                name= ] a e c
      k = add i j
      l = call[ call_jaxpr={ lambda ; a b c.
                             let d = mul a b
                                 e = mul d c
                             in (e,) }
                name= ] a b f
      m = add k l
  in (g, m) }
```

shoyer commented 4 years ago

@apaszke your example looks really elegant. Here's a test case for matrix multiplication:

```python
@multilinear
@call
def g(x, y):
  return x @ y

x = jnp.ones((3, 4))
y = jnp.ones((4, 5))
i = (x, y)

z = jnp.ones((3, 5))

print(jax.make_jaxpr(g)(*i))  # works
print(jax.make_jaxpr(lambda args, tans: jax.jvp(g, args, tans))(i, i))  # works
print(jax.make_jaxpr(lambda args, ctans: jax.vjp(g, *args)[1](ctans))(i, z))  # works
```

I'm still not sure how you manage to avoid the need for a transpose rule!

EDIT: this actually works correctly; I had a bug in my original example.

shoyer commented 4 years ago

I see now that @apaszke's example works because JAX transposes the implementation of the multilinear function.
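A self-contained check of this observation, restating @apaszke's `multilinear` decorator with the JVP rule written via tuple slicing instead of the `_with_ith` helper, and without the `call` wrapper, which only serves to make the printed jaxprs readable. The concrete matrices are arbitrary test values.

```python
import jax
import jax.numpy as jnp

def multilinear(f):
    fml = jax.custom_jvp(f)

    @fml.defjvp
    def multilinear_jvp(primals, tangents):
        # one call of f per argument, with that argument swapped for its tangent
        out = f(*primals)
        tan = sum(f(*primals[:i], t, *primals[i + 1:]) for i, t in enumerate(tangents))
        return out, tan

    return fml

@multilinear
def g(x, y):
    return x @ y

x = jnp.arange(12.0).reshape(3, 4)
y = jnp.arange(20.0).reshape(4, 5) / 10
z = jnp.ones((3, 5))

# Reverse mode works: JAX transposes the calls to g inside the JVP rule,
# i.e. it transposes the implementation of g itself.
_, g_vjp = jax.vjp(g, x, y)
ct_x, ct_y = g_vjp(z)
print(jnp.allclose(ct_x, z @ y.T), jnp.allclose(ct_y, x.T @ z))  # True True
```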

I suspect this will not suffice for @Jakob-Unfried's use case, because it doesn't resolve the fundamental issue that transpose(approx(multilinear_fun)) is not a good substitute for approx(transpose(multilinear_fun)). We need some way to customize the transpose rule, too.

danieldjohnson commented 4 years ago

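Chiming in with a bit of context, as I understand it. JAX computes reverse-mode derivatives in two steps:

1. Linearize the function, i.e. compute its JVP, which is a linear function of the tangents.
2. Transpose that linear function to get the VJP, f*, which maps output cotangents to input cotangents.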

The approach described above (and custom_jvp in general) tells JAX how you want your function to be linearized. When you compute a JVP of it, JAX then just uses the function you gave it instead of doing its normal linearization. When you use vjp, it uses your linearization but still transposes it using its own set of rules. (And in fact, if your custom_jvp rule applies nonlinear operations to the tangents, you will probably get an error about a missing transpose rule!)

There's a chance that the automatic transposition will do the right thing, but my guess is that it won't in general. That is @shoyer's point above, that transposing an approximation might not give the right approximate transpose (it might not work at all if you use anything nonlinear, or it might not be a good approximation if it works). If it doesn't work, the two existing options that I know about are custom_vjp (which lets you skip both the first and the second step and write f* directly) and implementing a custom primitive with a transpose rule (which is similar, but requires more knowledge of Jax internals; on the other hand it means that you would probably get forward mode, reverse mode, and higher-order derivatives for free).

I'm not sure there's a general way to automatically derive an approximate transpose from an approximate linear function; that seems hard to do in general, since it probably depends on exactly how your approximation works.

edit: One more detail that's relevant here is that, since a multilinear function is linear in each argument separately, it effectively has a different transpose with respect to each of its arguments. (You can almost see this in the @multilinear decorator, since it splits the single call to your function into n calls to your function; after that, each call will be transposed separately with respect to the single tangent argument.)
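In formulas (notation introduced here for illustration): fixing every argument except the $i$-th turns $f$ into a linear map $L_i$, and each $L_i$ has its own transpose, defined by the inner-product identity below, where $\langle\cdot,\cdot\rangle$ is the sum of elementwise products (no complex conjugation). The JVP sums the $L_i$, while the VJP applies each $L_i^{\top}$ separately to the same output cotangent $\bar y$. For a full contraction with scalar output, $L_i^{\top}(\bar y)$ is "contract everything except $T_i$", scaled by $\bar y$.

$$
\begin{aligned}
L_i(\dot T_i) &= f(T_1, \dots, T_{i-1}, \dot T_i, T_{i+1}, \dots, T_n), \\
\langle \bar y,\, L_i(\dot T_i) \rangle &= \langle L_i^{\top}(\bar y),\, \dot T_i \rangle \quad \text{for all } \dot T_i,\ \bar y, \\
\dot y &= \sum_{i=1}^{n} L_i(\dot T_i), \qquad \bar T_i = L_i^{\top}(\bar y), \quad i = 1, \dots, n.
\end{aligned}
$$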

Jakob-Unfried commented 4 years ago

update from me:

I could not come up with a working transpose rule.

What @shoyer did in his notebook with the transpose rule for einsum is just not possible with my implementation. What is needed here is, in essence, "contract everything but leave one tensor/operand out; the uncontracted axes, which would normally have been contracted with that one tensor, are the axes of the output of the transpose rule". If I could do that, I could use it to directly write down the VJP: that contraction (times the output cotangent) is the cotangent for that one tensor.
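For an exact contraction, the "contract everything but leave one tensor out" operation described above can be checked against JAX: the transpose of the linear map b ↦ fun(a, b, c) (obtained here with `jax.linear_transpose`) equals the contraction of all other tensors with b's indices left open, times the output cotangent. The einsum network and its shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def fun(a, b, c):
    # exact stand-in network with scalar output
    return jnp.einsum("ijp,jkp,ki->", a, b, c)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
a = jax.random.normal(keys[0], (2, 3, 5))
b = jax.random.normal(keys[1], (3, 4, 5))
c = jax.random.normal(keys[2], (4, 2))
ct = jnp.array(1.5, dtype=jnp.float32)  # output cotangent (a scalar, like the output)

# Transpose of the linear map b -> fun(a, b, c), with a and c held fixed.
transpose_b = jax.linear_transpose(lambda b_: fun(a, b_, c), b)
(ct_b,) = transpose_b(ct)

# "Contract everything except b": leave b's indices j, k, p open.
ct_b_manual = ct * jnp.einsum("ijp,ki->jkp", a, c)

print(ct_b.shape, jnp.allclose(ct_b, ct_b_manual, atol=1e-5))  # (3, 4, 5) True
```

With an approximate contraction, the einsum in `ct_b_manual` is exactly the piece that would need an approximate counterpart.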

In the linked notebook, @shoyer could use einsum in the transpose rule, just with different index strings. My approximate function does not have that flexibility: it is like an approximate einsum with fixed (and quite specific) index strings.

I am quite certain that unless you can supply JAX with an (approximate or exact) implementation of "contract everything except for one tensor", there is no chance of having either a transpose or a VJP rule. But if you can, both are simple.

So it looks like I am left with the JVP and forward mode.
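With only a JVP rule, a full gradient has to be assembled from forward-mode passes, e.g. with `jax.jacfwd`, which evaluates one JVP per input element (batched with vmap), so the cost grows with the number of inputs. A sketch using a toy trilinear function `f` as an illustrative stand-in for the approximate contraction, with the per-argument JVP rule discussed in this thread:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(a, b, c):
    # toy stand-in for the approximate contraction (scalar output)
    return jnp.sum(a * b * c)

@f.defjvp
def f_jvp(primals, tangents):
    a, b, c = primals
    da, db, dc = tangents
    # multilinear JVP rule: one call of f per argument
    return f(a, b, c), f(da, b, c) + f(a, db, c) + f(a, b, dc)

a = jnp.arange(1.0, 5.0)
b = jnp.full(4, 2.0)
c = jnp.linspace(0.0, 1.0, 4)

# Forward mode only: jacfwd builds the gradient from one JVP per input element.
grads = jax.jacfwd(f, argnums=(0, 1, 2))(a, b, c)
print(jnp.allclose(grads[0], b * c), jnp.allclose(grads[1], a * c))  # True True
```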

Thanks for all the insights; I learned a lot from this thread.