jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

The linear_transpose of lax.scan gives an error #6619

Open inailuig opened 3 years ago

inailuig commented 3 years ago

First, some motivation:

I have a Hermitian linear operator that applies a jvp followed by a vjp of a complex (shallow) neural network, and which roughly looks like this:

def matvec(apply_fun, params, samples, v):
    # forward-mode: w = J v, with J the Jacobian of apply_fun w.r.t. params
    _, w = jax.jvp(lambda p: apply_fun(p, samples), (params,), (v,))
    w = w.conj()
    # reverse-mode: pull the conjugated w back through apply_fun; together
    # with the final conj this gives the Hermitian product J^H J v
    res, _ = jax.vjp(apply_fun, params, samples)[1](w)
    return jax.tree_map(jax.lax.conj, res)

It works fine; however, storing the forward pass for the vjp uses too much memory when there are many parameters and samples (think of both on the order of ~1M). One can instead split the vjp into smaller vjps, looping over batches of samples and summing the results (which works surprisingly well), as sketched right below; the full snippet further down should make this clear.
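
Schematically, the batched vjp looks something like this (a rough sketch with a hypothetical helper; apply_fun, params, samples, and the cotangent w are as in matvec above, and batchsize is assumed to divide the number of samples evenly):

import jax
import jax.numpy as jnp

def batched_vjp(apply_fun, params, samples, w, batchsize):
    # split the samples (and the matching cotangents) into small batches
    samples_b = samples.reshape((-1, batchsize) + samples.shape[1:])
    w_b = w.reshape((-1, batchsize) + w.shape[1:])

    def step(acc, batch):
        s, wb = batch
        # vjp over one small batch only, so only that batch's forward pass is live
        cot_params, _ = jax.vjp(apply_fun, params, s)[1](wb)
        return jax.tree_map(jnp.add, acc, cot_params), None

    zeros = jax.tree_map(jnp.zeros_like, params)
    total, _ = jax.lax.scan(step, zeros, (samples_b, w_b))
    return total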

I am using this linear operator with the cg and gmres iterative solvers, where it is required to be linear_transpose'able. I have implemented the loop with lax.scan to try to make it transposable; however, the _scan_transpose rule raises an error that I don't really know how to solve.

Here is a minimal snippet of code (where, for simplicity, I have replaced the jvp/vjp in the linear operator with products with the matrix J) which reproduces the error I am getting:

import jax
import jax.numpy as jnp
from functools import partial

n = 100
batchsize = 5
d = 10

# generate a random jacobian and a vector
k0, k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 4)
J = jax.random.uniform(k0, shape=(n,d)) + 1j * jax.random.uniform(k1, shape=(n,d))
v = jax.random.uniform(k2, shape=(d,)) + 1j * jax.random.normal(k3, shape=(d,))

# the linear operator
def mv(J, v):
    return J.T.conj()@(J@v)

# works fine
_ = jax.scipy.sparse.linalg.cg(partial(mv, J), v)
_ = jax.scipy.sparse.linalg.gmres(partial(mv, J), v, solve_method='incremental')

# mv is hermitian, so its transpose is trivial
def mv_transposed(J, v):
    return mv(J, v.conj()).conj()

# now the operator where the vjp is looping over small batches
def mv_batched(J, v):    
    w = J@v

    J_batched = J.reshape((-1, batchsize)+J.shape[1:])
    w_batched = w.reshape((-1, batchsize)+w.shape[1:])
    Jw = J_batched, w_batched

    def _mv(Jw):
        J, w = Jw
        return (w.conj() @ J).conj() 

    def f(carry, x):
        return carry + _mv(x), None

    res, _ = jax.lax.scan(f, jnp.zeros_like(v), Jw, unroll=1)

    return res

assert jnp.allclose(mv(J,v), mv_batched(J,v))
assert jnp.allclose(mv_transposed(J,v), jax.linear_transpose(partial(mv, J), v)(v)[0])

# linear_transposing mv_batched gives the error:
print( jax.linear_transpose(partial(mv_batched, J), v)(v)[0] )

# goal: use mv_batched with the iterative solvers
# same error:
# print( jax.scipy.sparse.linalg.cg(partial(mv_batched, J), v))
# print( jax.scipy.sparse.linalg.gmres(partial(mv_batched, J), v, solve_method='incremental'))

And the output:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-1-dbeaca6d614d> in <module>
     45 # jax.scipy.sparse.linalg.gmres(partial(mv_batched, J), v, solve_method='incremental')
     46 
---> 47 print( jax.linear_transpose(partial(mv_batched, J), v)(v)[0] )

/usr/local/lib/python3.9/site-packages/jax/api.py in transposed_fun(out_cotangent)
   1957     in_cotangents = map(
   1958         ad.instantiate_zeros,
-> 1959         ad.backward_pass(jaxpr, consts, dummies, out_cotangents))
   1960     return tree_unflatten(in_tree, in_cotangents)
   1961 

/usr/local/lib/python3.9/site-packages/jax/interpreters/ad.py in backward_pass(jaxpr, consts, primals_in, cotangents_in)
    216             params, call_jaxpr, invals, cts_in, cts_in_avals)
    217       else:
--> 218         cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
    219                                                          **eqn.params)
    220     cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out

/usr/local/lib/python3.9/site-packages/jax/_src/lax/control_flow.py in _scan_transpose(cts, reverse, length, num_consts, num_carry, jaxpr, linear, unroll, *args)
   1680   _, eres = split_list(xs, [sum(xs_lin)])
   1681   assert not any(ad.is_undefined_primal(r) for r in ires)
-> 1682   assert not any(ad.is_undefined_primal(r) for r in eres)
   1683 
   1684   carry_avals, y_avals = split_list(jaxpr.out_avals, [num_carry])

AssertionError: 

[I thought of just using linear_call; however, it's still missing an XLA translation rule, so it currently can't be jitted.]
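
For reference, supplying the transpose by hand via linear_call would look roughly like this (a sketch; I'm assuming linear_call is importable from jax.custom_derivatives and exploiting that mv_batched is Hermitian):

from jax.custom_derivatives import linear_call  # assumed import path

def mv_batched_lc(J, v):
    # forward map (linear in v) and its explicit transpose; since the
    # operator is Hermitian, the transpose is u -> mv_batched(J, u.conj()).conj()
    return linear_call(
        lambda J, v: mv_batched(J, v),
        lambda J, ct: mv_batched(J, ct.conj()).conj(),
        J,   # residual arguments
        v)   # linear arguments

# jax.linear_transpose(partial(mv_batched_lc, J), v) would then call the
# supplied transpose instead of trying to transpose the scan.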

Please see this more as a feature request for a more general transpose rule (unless there's a trivial trick to make it transposable that I'm missing). Otherwise, I would also be happy to have a go at relaxing the requirement that the linear operator be transposable in the solvers, if you point me in the right direction, since there are several possible ways to do this.

froystig commented 3 years ago

Whenever a scan primitive is bound via lax.scan(f, ...), the scanned function f is assumed to be non-linear in all of its inputs. Only in transformation rules (such as the scan JVP rule) are scan primitives bound with flags indicating linearity.

Knowing that, we could reproduce this error more minimally by calling linear_transpose on any call to lax.scan:

from jax import lax, numpy as jnp, linear_transpose as tr

def csum(x):
  zero = jnp.zeros(x.shape[1:])
  return lax.scan(lambda c, x: (c + x, c + x), zero, x)[1]

x = jnp.arange(5.)
print(tr(csum, x)(x))
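
For contrast, transposing the linearized function should go through, since jax.linearize applies the scan JVP rule, which is where the scan gets bound with its linearity flags (a small sketch, continuing the snippet above):

from jax import linearize

_, csum_lin = linearize(csum, x)
print(tr(csum_lin, x)(x))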

Some possible ways forward: