jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.numpy.linalg.matrix_power does not play well with jit #4723

Open martiningram opened 3 years ago

martiningram commented 3 years ago

Dear JAX team,

thanks for all the amazing work you're doing!

I'm using jax.numpy.linalg.matrix_power but am running into an issue when trying to use it with jit. Here's a minimal example:

import jax.numpy as jnp
from jax import jit

def minimal_example(F, x, n):
    return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T

sample_F = jnp.eye(2)
sample_x = jnp.array([2, 1])
sample_n = 2

# Works fine without JIT
minimal_example(sample_F, sample_x, sample_n)

jit_fun = jit(minimal_example)
jit_fun(sample_F, sample_x, sample_n)

The last line produces an error. Here's the full trace:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-163-a3fbaea19733> in <module>
----> 1 jit_fun(sample_F, sample_x, sample_n)

<ipython-input-159-ae3772d83a8f> in minimal_example(F, x, n)
      2 
----> 3     return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
     75   except TypeError as err:
---> 76     raise TypeError("exponent must be an integer, got {}".format(n)) from err
     77 

FilteredStackTrace: TypeError: exponent must be an integer, got Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
     73   try:
---> 74     n = operator.index(n)
     75   except TypeError as err:

TypeError: 'DynamicJaxprTracer' object cannot be interpreted as an integer

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
<ipython-input-163-a3fbaea19733> in <module>
----> 1 jit_fun(sample_F, sample_x, sample_n)

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    137   def reraise_with_filtered_traceback(*args, **kwargs):
    138     try:
--> 139       return fun(*args, **kwargs)
    140     except Exception as e:
    141       if not is_under_reraiser(e):

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    215         backend=backend,
    216         name=flat_fun.__name__,
--> 217         donated_invars=donated_invars)
    218     return tree_unflatten(out_tree(), out)
    219 

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1160 
   1161   def bind(self, fun, *args, **params):
-> 1162     return call_bind(self, fun, *args, **params)
   1163 
   1164   def process(self, trace, fun, tracers, params):

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1151   tracers = map(top_trace.full_raise, args)
   1152   with maybe_new_sublevel(top_trace):
-> 1153     outs = primitive.process(top_trace, fun, tracers, params)
   1154   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1155 

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1163 
   1164   def process(self, trace, fun, tracers, params):
-> 1165     return trace.process_call(self, fun, tracers, params)
   1166 
   1167   def post_process(self, trace, out_tracers, params):

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    573 
    574   def process_call(self, primitive, f, tracers, params):
--> 575     return primitive.impl(f, *tracers, **params)
    576   process_map = process_call
    577 

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    555 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
    556   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 557                                *unsafe_map(arg_spec, args))
    558   try:
    559     return compiled_fun(*args)

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    245       fun.populate_stores(stores)
    246     else:
--> 247       ans = call(fun, *args)
    248       cache[key] = (ans, fun.stores)
    249 

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    630   abstract_args, arg_devices = unzip2(arg_specs)
    631   if config.omnistaging_enabled:
--> 632     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
    633     if any(isinstance(c, core.Tracer) for c in consts):
    634       raise core.UnexpectedTracerError("Encountered an unexpected tracer.")

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
   1036     main.source_info = fun_sourceinfo(fun.f)  # type: ignore
   1037     main.jaxpr_stack = ()  # type: ignore
-> 1038     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1039     del main
   1040   return jaxpr, out_avals, consts

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1017     trace = DynamicJaxprTrace(main, core.cur_sublevel())
   1018     in_tracers = map(trace.new_arg, in_avals)
-> 1019     ans = fun.call_wrapped(*in_tracers)
   1020     out_tracers = map(trace.full_raise, ans)
   1021   jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    154 
    155     try:
--> 156       ans = self.f(*args, **dict(self.params, **kwargs))
    157     except:
    158       # Some transformations yield from inside context managers, so we have to

<ipython-input-159-ae3772d83a8f> in minimal_example(F, x, n)
      1 def minimal_example(F, x, n):
      2 
----> 3     return jnp.linalg.matrix_power(F, n) @ x @ jnp.linalg.matrix_power(F, n).T

~/miniconda3/envs/jax/lib/python3.7/site-packages/jax/numpy/linalg.py in matrix_power(a, n)
     74     n = operator.index(n)
     75   except TypeError as err:
---> 76     raise TypeError("exponent must be an integer, got {}".format(n)) from err
     77 
     78   if n == 0:

TypeError: exponent must be an integer, got Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>

Is there some way around this? I suppose I could mark n as a static_argnum, but recompiling for every distinct n would be very inefficient for my application.
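For reference, that workaround would look something like this (just a sketch; the point is that each new value of n triggers a fresh compile):

from jax import jit

# Treat n (argument index 2) as a compile-time constant:
jit_static = jit(minimal_example, static_argnums=(2,))
jit_static(sample_F, sample_x, sample_n)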

Thanks!

mattjj commented 3 years ago

Martin, thanks for the kind words as always! It's really very encouraging.

I think we could revise the matrix_power implementation not to rely on Python control flow. That is, we can replace uses of Python control flow with lax.switch and lax.while_loop, so that we'd be able to stage it out with jit no problem. (If we use lax.while_loop for memory efficiency, we'd probably need to define a custom jvp.)

mattjj commented 3 years ago

Do you use reverse-mode differentiation through this?

martiningram commented 3 years ago

Thanks for the super-fast response, Matt! That sounds great. Ideally I was planning to use reverse-mode autodiff, but I realise that while_loop doesn't currently support that. I'm not sure how much slower forward mode would be, but I don't have a large number of parameters, so it would probably be fine!

mattjj commented 3 years ago

I think we can work out how to define a custom differentiation rule to make reverse-mode (and forward-mode) work efficiently. I mainly wanted to know if there was an easy way to get you un-stuck.

Can you use expm together with some matrix analogue of x^n = exp(n * log x)? Hmm, it seems that we don't have a matrix logarithm function...

martiningram commented 3 years ago

Thanks, Matt! Following your suggestion, I've made a while_loop-based version:

import jax.numpy as jnp
from jax import jit
from jax.lax import while_loop

@jit
def matrix_power_while_inner(val, F):
    # One matrix multiplication per iteration; decrement the counter.
    i, cur_val = val
    return i - 1, F @ cur_val

@jit
def matrix_power_while(F, n):
    cond_fun = lambda val: val[0] >= 0
    init_val = (n - 1, jnp.eye(F.shape[0]))
    body_fun = lambda val: matrix_power_while_inner(val, F)

    res = while_loop(cond_fun, body_fun, init_val)

    return res[1]

F = jnp.array([[1.0, 0.5], [0.0, 1.0]])  # example matrix for the check

# Returns True:
jnp.allclose(matrix_power_while(F, 10), jnp.linalg.matrix_power(F, 10))

I'll give that a go with forward mode for now. Let me know; I'd be happy to try to adapt this if it's a reasonable way to go for a new version of matrix_power.
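For what it's worth, here's a quick check of the differentiation story, assuming the definitions above: forward mode goes through while_loop fine, but reverse mode does not, since while_loop has no transpose rule.

from jax import jacfwd

# Forward-mode Jacobian of F -> F^10 works through while_loop:
jac = jacfwd(lambda F: matrix_power_while(F, 10))(F)

# jax.grad / jax.jacrev here would raise something like
# "Reverse-mode differentiation does not work for lax.while_loop".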

mattjj commented 3 years ago

Is there an upper bound on the exponent? We should probably just write something in terms of lax.scan (together with lax.cond for "early exit"), where the length is ceil(log2(upper_bound_on_exponent)). Even if the upper bound is 2**32, I'm guessing you can tolerate storing 32 copies of your array. WDYT?

martiningram commented 3 years ago

Hey Matt, there's definitely an upper bound which I know in advance, and it's certainly much smaller than 2**32, probably less than 2**12, actually. Here's a new version:

import jax.numpy as jnp
from jax.lax import cond, scan
from jax import jit

@jit
def scan_fun(carry, xs):
    # One step of binary exponentiation: peel off the lowest bit of n,
    # multiply z into the result if the bit is set, then square z.
    n, z, result = carry
    new_n, bit = jnp.divmod(n, 2)

    new_result = cond(bit, lambda x: z @ x, lambda x: x, result)

    # No more computation necessary once n = 0.
    # Is there a better way to early-break rather than just returning something empty?
    new_z = cond(new_n, lambda z: z @ z, lambda _: jnp.empty(z.shape), z)

    return (new_n, new_z, new_result), None

@jit
def matrix_power_scan(F, n, upper_limit=32):
    # TODO: I think we can avoid setting the third carry element to eye and save one matrix multiply
    init_carry = n, F, jnp.eye(F.shape[0])

    result = cond(n == 1,
                  lambda _: F,
                  lambda _: scan(scan_fun, init_carry, None, length=upper_limit)[0][2],
                  F)

    return result

F = jnp.array([[1.0, 0.5], [0.0, 1.0]])  # example matrix for the check
n = 140

# Returns True
jnp.allclose(matrix_power_scan(F, n), jnp.linalg.matrix_power(F, n))

Thanks for the pointer. I don't think this is completely ideal (pretty sure it does one matrix multiply too many, see the comments), but it seems to work, and it'll be much faster than my hopelessly inefficient naive version; I should have thought of the log trick! Let me know if you have any thoughts and whether this makes sense.
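One more sanity check, assuming the definitions above: since scan (unlike while_loop) supports reverse-mode autodiff, gradients should now flow through this version.

from jax import grad

# Reverse mode works because scan, unlike while_loop, has a transpose rule:
def loss(F):
    return jnp.sum(matrix_power_scan(F, 5))

g = grad(loss)(jnp.array([[1.0, 0.5], [0.0, 1.0]]))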