Open martiningram opened 3 years ago

Dear JAX team,

thanks for all the amazing work you're doing!

I'm using `jax.numpy.linalg.matrix_power` but am running into an issue when trying to use it with `jit`. Here's a minimal example:

The last line produces an error. Here's the full trace:

Is there some way around this? I suppose I could declare `n` a static_argnum, but this would be very inefficient for my application. Thanks!
Martin, thanks for the kind words as always! It's really very encouraging.
I think we could revise the `matrix_power` implementation not to rely on Python control flow. That is, we can replace uses of Python control flow with `lax.switch` and `lax.while_loop`, so that we'd be able to stage it out with `jit` no problem. (If we use `lax.while_loop` for memory efficiency, we'd probably need to define a custom jvp.) Do you use reverse-mode differentiation through this?
Thanks for the super-fast response Matt! That sounds great. Ideally I was planning to do reverse-mode autodiff, but I realise that `while_loop` doesn't currently support that. I'm not sure how much slower forward mode would be, but I don't have a great number of parameters, so it would probably be fine!
I think we can work out how to define a custom differentiation rule to make reverse-mode (and forward-mode) work efficiently. I mainly wanted to know if there was an easy way to get you un-stuck.
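To make "custom differentiation rule" a bit more concrete, here is one hypothetical shape such a rule could take. This is only my sketch, not JAX's actual implementation: the names (`mpow` etc.) are made up, and it sidesteps the dynamic-`n` question by marking `n` static and non-differentiable via `nondiff_argnums`. It uses the identity that, for `y = F^n` with output cotangent `g`, the cotangent of `F` is `sum_{k=0}^{n-1} (F^T)^k g (F^T)^{n-1-k}`, accumulated with the recurrence `T_1 = g`, `T_{m+1} = F^T T_m + g (F^T)^m`:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def mpow(F, n):
    # Plain while_loop product: F @ F @ ... @ I, n multiplications.
    _, out = lax.while_loop(
        lambda v: v[0] >= 0,
        lambda v: (v[0] - 1, F @ v[1]),
        (n - 1, jnp.eye(F.shape[0], dtype=F.dtype)))
    return out


def mpow_fwd(F, n):
    # Save only F as the residual, so memory use stays O(1) in n.
    return mpow(F, n), F


def mpow_bwd(n, F, g):
    # F_bar = sum_{k=0}^{n-1} (F^T)^k g (F^T)^{n-1-k}, built up with
    # T_1 = g, Q_1 = F^T, then T_{m+1} = F^T T_m + g Q_m, Q_{m+1} = F^T Q_m.
    Ft = F.T

    def body(v):
        m, T, Q = v
        return m - 1, Ft @ T + g @ Q, Ft @ Q

    _, T, _ = lax.while_loop(lambda v: v[0] > 0, body, (n - 1, g, Ft))
    return (T,)


mpow.defvjp(mpow_fwd, mpow_bwd)
```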
Can you use `expm` together with some matrix analogue of x^n = exp(n * log x)? Hmm, it seems that we don't have a matrix logarithm function...
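For what it's worth, the identity itself is easy to check numerically with SciPy. A minimal sketch with a made-up matrix; `scipy.linalg.logm` is plain SciPy, so this isn't jittable or differentiable, and the identity needs the eigenvalues to stay off the negative real axis:

```python
import numpy as np
from scipy.linalg import expm, logm

# Example matrix with positive real eigenvalues (1.0 and 0.7),
# so logm is well-defined and expm(n * logm(F)) == F**n.
F = np.array([[0.9, 0.1],
              [0.2, 0.8]])
n = 5

print(np.allclose(expm(n * logm(F)), np.linalg.matrix_power(F, n)))  # True
```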
Thanks Matt! Following your suggestion, I've made a `while_loop`-based version:

```python
import jax.numpy as jnp
from jax import jit
from jax.lax import while_loop

F = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # example square matrix

@jit
def matrix_power_while_inner(val, F):
    # Multiply the running product by F and decrement the counter.
    i, cur_val = val
    return i - 1, F @ cur_val

@jit
def matrix_power_while(F, n):
    # Run n multiplications: i counts down from n - 1 to 0.
    cond_fun = lambda val: val[0] >= 0
    init_val = (n - 1, jnp.eye(F.shape[0]))
    body_fun = lambda val: matrix_power_while_inner(val, F)
    res = while_loop(cond_fun, body_fun, init_val)
    return res[1]

# Returns True:
jnp.allclose(matrix_power_while(F, 10), jnp.linalg.matrix_power(F, 10))
```
I'll give that a go with forward mode for now. Let me know; I'd be happy to try to adapt this if it's a reasonable way to go for a new version of `matrix_power`.
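As a quick check that forward mode does go through the `while_loop` version (`lax.while_loop` supports forward-mode but not reverse-mode autodiff), here's a minimal sketch reusing `F` and `matrix_power_while` from the snippet above, with a made-up tangent direction `dF`:

```python
import jax

# Push an arbitrary tangent dF through F -> matrix_power_while(F, 10).
# while_loop is forward-mode differentiable, so jax.jvp works here.
dF = jnp.ones_like(F)  # made-up tangent direction, just for illustration
primal_out, tangent_out = jax.jvp(
    lambda M: matrix_power_while(M, 10), (F,), (dF,))
```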
Is there an upper bound on the exponent? We should probably just write something in terms of `lax.scan` (together with `lax.cond` for "early exit"), where the length is `ceil(log2(upper_bound_on_exponent))`. Even if the upper bound is 2**32, I'm guessing you can tolerate storing 32 copies of your array. WDYT?
Hey Matt, there's definitely an upper bound which I know in advance, and it's certainly much smaller than 2**32, probably less than 2**12 actually. Here's a new version:
```python
import jax.numpy as jnp
from jax.lax import cond, scan
from jax import jit
from jax.numpy import divmod

F = jnp.array([[0.9, 0.1], [0.2, 0.8]])  # example square matrix
n = 140

@jit
def scan_fun(carry, xs):
    # One step of binary exponentiation: peel off the lowest bit of n,
    # multiply z into the result if the bit is set, then square z.
    n, z, result = carry
    new_n, bit = divmod(n, 2)
    new_result = cond(bit, lambda x: z @ x, lambda x: x, result)
    # No more computation necessary if n = 0.
    # Is there a better way to early break rather than just returning something empty?
    new_z = cond(new_n, lambda z: z @ z, lambda _: jnp.empty(z.shape), z)
    return (new_n, new_z, new_result), None

@jit
def matrix_power_scan(F, n, upper_limit=32):
    # TODO: I think we can avoid setting the third carry element to eye
    # and save one matrix multiply.
    # Note: scan's `length` must be a Python int, so under jit upper_limit
    # is only safe as the (static) default value used here.
    init_carry = n, F, jnp.eye(F.shape[0])
    result = cond(n == 1,
                  lambda _: F,
                  lambda _: scan(scan_fun, init_carry, None, length=upper_limit)[0][2],
                  F)
    return result

# Returns True
jnp.allclose(matrix_power_scan(F, n), jnp.linalg.matrix_power(F, n))
```
Thanks for the pointer. I don't think this is completely ideal (pretty sure it does one matrix multiply too many, see comments), but it seems to work, and it'll be much faster than my hopelessly inefficient naive version. I should have thought of the log trick! Let me know if you have any thoughts and whether this makes sense.
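One more note for anyone finding this later: unlike the `while_loop` version, this one should also be reverse-mode differentiable out of the box, since `scan` and `cond` both support reverse-mode autodiff. A quick sketch, reusing `F` and `n` from the snippet above, with `.sum()` standing in for whatever scalar loss you actually have:

```python
import jax

# grad flows through scan/cond, so no custom differentiation rule is needed.
grad_F = jax.grad(lambda M: matrix_power_scan(M, n).sum())(F)
```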