jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Dynamic sizes in dynamic programming #1859

Open · srush opened this issue 4 years ago

srush commented 4 years ago

Hi JAX team,

Could you give me advice on the following setting? I am in a dynamic programming loop (CKY) and need to construct dynamic_slices whose sizes depend on the loop index. I read https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html and I know this is discouraged; however, this inner loop is exactly the part of the code I want to be fast. JIT compiling the outer function requires re-JITing for every change in the static size (N = 59 versus N = 60), even though those two programs share 59 inner loops. Is there any way to cache and share the JIT-compiled inner loop across all 59 static instantiations?

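    def update(w, beta):
        # The slice sizes (N - w, w) depend on the loop index w, which is a
        # tracer inside fori_loop. lax.dynamic_slice requires static slice
        # sizes, so this raises an error when the loop body is traced.
        Y = jax.lax.dynamic_slice(beta[A], [0, 0], [N - w, w])
        Z = jax.lax.dynamic_slice(beta[B], [w, N - w], [N - w, w])
        ...
        return (beta[A], beta[B])

    # fori_loop returns the final carry, so keep its result.
    beta = jax.lax.fori_loop(1, N, update, beta)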
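A common workaround, not taken from this thread, is to give every slice a fixed, static shape and mask away the part that depends on the loop index, so `fori_loop` can trace the body once. The sketch below assumes a 2-D chart; `masked_window` and `width_sums` are hypothetical helpers, and the block sums stand in for the elided `...` of the real CKY update. The input is zero-padded before slicing because `lax.dynamic_slice` clamps start indices so the window stays in bounds, which would otherwise silently shift the window.

```python
import jax
import jax.numpy as jnp


def masked_window(x, start, valid, size):
    """Fixed-shape window of the 2-D array `x` whose valid region may be traced.

    Hypothetical helper: `size` must be static Python ints; `start` and
    `valid` may be traced. Entries outside the top-left `valid` block of the
    window are zeroed.
    """
    # Pad so a fixed-size window never runs past the edge; otherwise
    # lax.dynamic_slice clamps the start index and shifts the window.
    padded = jnp.pad(x, ((0, size[0]), (0, size[1])))
    win = jax.lax.dynamic_slice(padded, start, size)
    rows = jnp.arange(size[0])[:, None] < valid[0]
    cols = jnp.arange(size[1])[None, :] < valid[1]
    return jnp.where(rows & cols, win, 0)


def width_sums(a, b, n):
    """For w = 1..n-1, sum the (n-w, w) blocks a[:n-w, :w] and b[w:, n-w:]."""
    def body(w, acc):
        # (n, n) is the static window shape; the traced (n - w, w) only
        # controls the mask, mirroring Y and Z in the question.
        y = masked_window(a, (0, 0), (n - w, w), (n, n))
        z = masked_window(b, (w, n - w), (n - w, w), (n, n))
        return acc.at[w].set(y.sum() + z.sum())

    return jax.lax.fori_loop(1, n, body, jnp.zeros(n, a.dtype))


# n is used as a slice size, so it must be a static argument under jit.
width_sums_jit = jax.jit(width_sums, static_argnums=2)
```

The cost of this design is extra work: every iteration touches an n×n window instead of an (n-w)×w block. Because `n` is still static, N = 59 and N = 60 still compile separately; this only makes the inner loop traceable.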

@mattjj @dougalm

mattjj commented 4 years ago

Thanks for opening this, @srush!

Just a quick note: I followed up with the XLA team, and it turns out that being able to lower to a single shape-polymorphic XLA program is much closer than I realized.

I'll update here as I learn more!

nickbhat commented 4 years ago

Hi @mattjj, I'm curious whether there's any news on this front. This is a feature I'd be very interested in as well!