aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Apply JAX JIT more selectively #684

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

It seems JAX JIT is fundamentally more limited in the types of graphs that it can handle compared to Aesara's C and Numba backends, as described in https://github.com/aesara-devs/aesara/issues/43, https://github.com/aesara-devs/aesara/issues/68, and a couple of other closed issues. This mostly relates to handling of dynamic shapes.

Right now we attempt to JIT the entire graph when using the JAX backend via Aesara, which can often fail. A more flexible alternative (with potentially bad performance) would be to JIT only the subgraphs that are compatible with JAX's limitations, always trying to JIT at the highest possible level of the graph. This would probably require some metadata on Ops and possibly compile-time graph analysis to figure out whether some of the computations can be safely "concretized", i.e. done with the NumPy API, which sounds like a more specialized form of constant folding.

We might also consider JAX-specific rewrites that lift reshaping operations closer to the inputs so that more of the graph can be jitted. I have no idea if those opportunities actually exist.

It sounds like #431 would also help accommodate more of the graph types produced by Aesara (also related to #182).
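As a rough illustration of the "JIT only the compatible subgraphs" idea, here is a minimal sketch in plain JAX and NumPy. It does not use any Aesara machinery: the split into `dynamic_part` and `static_part` is done by hand, and all function names are hypothetical. The value-dependent indexing runs eagerly with NumPy, and only the shape-stable part is jitted.

```python
import jax
import jax.numpy as jnp
import numpy as np


def whole_graph(x):
    # Boolean masking has a value-dependent output shape, so tracing fails.
    return jnp.exp(x[x > 0]) * 2.0


def dynamic_part(x):
    # Hypothetical "non-jittable" subgraph: evaluated eagerly with NumPy.
    return x[x > 0]


@jax.jit
def static_part(x):
    # Hypothetical "jittable" subgraph: the output shape follows the input shape.
    return jnp.exp(x) * 2.0


def partially_jitted(x):
    # Concretize the dynamic step first, then hand off to the jitted part.
    return static_part(dynamic_part(np.asarray(x)))


x = np.array([-1.0, 0.5, 2.0])

try:
    jax.jit(whole_graph)(x)
    whole_graph_failed = False
except Exception:
    whole_graph_failed = True

out = partially_jitted(x)
```

Note that `static_part` is re-traced for every new input shape, which is one source of the potentially bad performance mentioned above.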

rlouf commented 2 years ago

Another option is to let users jit functions themselves. It is more flexible for them, less of a hassle library-side.

brandonwillard commented 2 years ago

> Another option is to let users jit functions themselves. It is more flexible for them, less of a hassle library-side.

In other words, do no JAX transpilation whatsoever?

ricardoV94 commented 2 years ago

You can transpile without JIT, but that's basically numpy?

brandonwillard commented 2 years ago

> You can transpile without JIT, but that's basically numpy?

Yes. More specifically, one can use their own JAXLinker subclass and override JAXLinker.jit_compile so that it does nothing.
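The override pattern described above can be sketched as follows. To keep the example self-contained it does not import Aesara: `SketchLinker` is a hypothetical stand-in for a linker exposing a `jit_compile` hook, not Aesara's actual `JAXLinker`, and `compile` and `reshape_to` are likewise made up for illustration.

```python
import jax
import jax.numpy as jnp


class SketchLinker:
    """Hypothetical stand-in for a linker with a `jit_compile` hook."""

    def jit_compile(self, fn):
        # Default behaviour assumed here: JIT the whole function.
        return jax.jit(fn)

    def compile(self, fn):
        return self.jit_compile(fn)


class NoJitLinker(SketchLinker):
    def jit_compile(self, fn):
        # Do nothing: return the plain (un-jitted) JAX function.
        return fn


def reshape_to(x, shape):
    return jnp.reshape(x, shape)


x = jnp.zeros(6)
no_jit_fn = NoJitLinker().compile(reshape_to)

# Runs eagerly, so a Python tuple is accepted as the shape.
eager_result = no_jit_fn(x, (2, 3))

# The caller remains free to JIT it with their own static arguments.
jit_result = jax.jit(no_jit_fn, static_argnums=(1,))(x, (2, 3))
```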

Unfortunately, none of that will help solve the problem(s) addressed by this issue.

brandonwillard commented 2 years ago

By the way, the idea of selectively transpiling only certain subgraphs was a part of the original implementation, and it has already been discussed. As the proposed solutions have implied, selective transpilation severely complicates things and leads to many open choices that could easily make or break the utility of transpilation in the first place. If the scope of our work and considerations were more restricted (e.g. only JAX compilation or standard Python evaluation), it might be feasible, but it's not.

Use of JAX-specific rewrites (e.g. to remove/avoid non-transpilable sub-graphs) is reasonable, though.

twiecki commented 2 years ago

Yeah, I would raise an error saying that the Op is not supported, which would motivate a workaround in user code or in Aesara.

rlouf commented 2 years ago

> You can transpile without JIT, but that's basically numpy?

No, it's not, since you can then jit-compile the resulting function yourself and thus specify static_argnums yourself. Many JAX libraries do not jit-compile functions for users because of the issues mentioned elsewhere on this repo.

I see this has already been discussed, and I understand the conclusion, but you will have to make choices for users in that case.

ricardoV94 commented 2 years ago

We can definitely leave the jitting up to the users.

rlouf commented 2 years ago

From what I've seen in the issue tracker this can be broken down into two separate problems:

  1. Shapes whose value is known at compile time by Aesara. In this case, as discussed in #702, the solution is to not convert np.ndarrays to DeviceArrays during transpilation.
import jax
import numpy as np

shape = np.array([10])  # currently jnp.array([10])

def jax_funcified(prng_key):
    return jax.random.normal(prng_key, shape)

key = jax.random.PRNGKey(0)
print(jax.jit(jax_funcified)(key))
#   [-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442 
# -0.67135346 -0.5908641   0.73168886  0.5673026 ]
  2. Shapes that Aesara does not know at compile time. It is not easy to predict what JAX will do, and we (at least, I) need to better understand what happens under the hood. Indeed, the following works and Aesara should transpile:
import jax
import jax.numpy as jnp
import numpy as np

def b(z):
    return z.shape

def a(x, z):
    return jnp.reshape(x, b(z))

x = np.zeros(6)
z = np.zeros((2,3))
jax_res = jax.jit(a)(x, z)
print(jax_res)
# [[0. 0. 0.]
#  [0. 0. 0.]]

but the following doesn't without static_argnums:

import jax
import jax.numpy as jnp
import numpy as np

def b(z):
    return z

def a(x, z):
    return jnp.reshape(x, b(z))

x = np.zeros(6)
z = (2,3)
try:
    jax_res = jax.jit(a)(x, z)
except Exception as e:
    print(e)
# Shapes must be 1D sequences of concrete values of integer type,
# got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>).

jax_res = jax.jit(a, static_argnums=(1,))(x, z)
print(jax_res)
# [[0. 0. 0.]
#  [0. 0. 0.]]

It looks like JAX eagerly converts all flattened inputs to traced arrays. Aesara should transpile here too. This is actually a good example of why we should not JIT-compile functions for the user: JIT-compilation of the above fails without static_argnums, but the following succeeds:

import jax
import jax.numpy as jnp
import numpy as np

def b(z):
    return z

def a(x, z):
    return jnp.reshape(x, b(z))

def c(x, y):
    z = y.shape
    return a(x, z)

x = np.zeros(6)
z = np.zeros((2,3))
jax_res = jax.jit(c)(x, z)
print(jax_res)
# [[0. 0. 0.]
#  [0. 0. 0.]]

We don't know what users intend to do with their functions in this case (assuming the goal of transpilation is use with the wider JAX ecosystem), and we may limit users for no good reason.

Sometimes (when?) Aesara may be able to determine at compile time that JIT-compilation will fail even if static_argnums is specified, whatever the inputs. In this case, as discussed in https://github.com/aesara-devs/aesara/discussions/1184, a solution is to fail gracefully during transpilation with an explanation.

Anyway, we can be smarter than JAX, but we should not try to compensate for its limitations in this backend; rather, we could do so by creating a new one that targets XLA directly, for instance.
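One possible shape for "failing gracefully during transpilation" is sketched below. This is only an assumption about how it could look, not Aesara's actual dispatch code: the `Op`, `Exp`, and `Nonzero` classes and the `to_jax` function are hypothetical names, and the dispatch is done with the standard library's `functools.singledispatch`.

```python
from functools import singledispatch

import jax.numpy as jnp


class Op:
    """Hypothetical base class for graph operations."""


class Exp(Op):
    """Hypothetical Op whose output shape follows its input shape."""


class Nonzero(Op):
    """Hypothetical Op whose output shape depends on the input values."""


@singledispatch
def to_jax(op):
    # Fallback: no conversion has been registered for this Op type.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")


@to_jax.register(Exp)
def _(op):
    return jnp.exp


@to_jax.register(Nonzero)
def _(op):
    # Fail at transpilation time with an explanation, instead of letting
    # the user hit an opaque tracer error later on.
    raise NotImplementedError(
        "Nonzero produces an output whose shape depends on the input values, "
        "which jax.jit cannot trace. Consider rewriting the graph to avoid it."
    )


exp_fn = to_jax(Exp())

try:
    to_jax(Nonzero())
    message = None
except NotImplementedError as err:
    message = str(err)
```

Raising at conversion time keeps the error close to the offending Op and gives room for a message that suggests a workaround.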