aesara-devs / aesara

Aesara is a Python library for defining, optimizing, and efficiently evaluating mathematical expressions involving multi-dimensional arrays.
https://aesara.readthedocs.io

Add a `is_static_jax` property to TensorVariable's `tag` #182

Open junpenglao opened 3 years ago

junpenglao commented 3 years ago

JAX's jit requires some function arguments to be static (for example, the shape in jnp.reshape or the length in jax.lax.scan). Currently, if these are symbolic inputs, jax.jit breaks at https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/sandbox/jax_linker.py#L80
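
The limitation described above can be reproduced in plain JAX. The following is a minimal sketch, not code from Theano; the function name `square_reshape` is made up for illustration, and the exact exception raised by the failing call may differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def square_reshape(a, n):
    # `n` determines the output shape, so JAX must know its concrete value
    # at trace time.
    return jnp.reshape(a, (n, n))


a = jnp.arange(4.0)

# Without `static_argnums`, `n` is traced as an abstract value and the
# reshape fails while tracing.
try:
    jax.jit(square_reshape)(a, 2)
    traced_shape_ok = True
except Exception:  # the exact exception type may vary between JAX versions
    traced_shape_ok = False

# Marking argument 1 as static makes JAX treat `n` as a compile-time constant
# (a new compilation is triggered for each distinct value of `n`).
out = jax.jit(square_reshape, static_argnums=(1,))(a, 2)
print(traced_shape_ok, out.shape)
```

The trade-off is that every new value of a static argument triggers a recompilation, so this suits shape-like inputs that change rarely.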

I propose we add a property to TensorVariable in:

diff --git a/theano/tensor/var.py b/theano/tensor/var.py
index 4cda4e5e1..6f2aaf398 100644
--- a/theano/tensor/var.py
+++ b/theano/tensor/var.py
@@ -872,6 +872,8 @@ class TensorVariable(_tensor_py_operators, Variable):

                 pdb.set_trace()

+    # Plain class attribute: a bound method would always be truthy
+    is_static_jax = False

 TensorType.Variable = TensorVariable

and to SharedVariable:

diff --git a/theano/compile/sharedvalue.py b/theano/compile/sharedvalue.py
index cc3dd3cce..ca3e7af3b 100644
--- a/theano/compile/sharedvalue.py
+++ b/theano/compile/sharedvalue.py
@@ -224,6 +224,9 @@ class SharedVariable(Variable):
     # We keep this just to raise an error
     value = property(_value_get, _value_set)

+    # Plain class attribute: a bound method would always be truthy
+    is_static_jax = False
+

 def shared_constructor(ctor, remove=False):
     if remove:

Then we can detect the additional static_argnums in:

diff --git a/theano/sandbox/jax_linker.py b/theano/sandbox/jax_linker.py
index 59b61caf3..0093c3fa7 100644
--- a/theano/sandbox/jax_linker.py
+++ b/theano/sandbox/jax_linker.py
@@ -62,7 +62,9 @@ class JAXLinker(PerformLinker):
         # I suppose we can consider `Constant`s to be "static" according to
         # JAX.
         static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
+            n
+            for n, i in enumerate(self.fgraph.inputs)
+            if isinstance(i, Constant) or i.is_static_jax
         ]

         thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]
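
To make the detection step concrete, here is a small self-contained sketch of the same comprehension. `Variable`, `Constant`, and `find_static_argnums` are hypothetical stand-ins written for this example, not the real Theano classes or linker code. One assumption is worth stating: the flag has to be a plain attribute (or a property) for the `or i.is_static_jax` check to work, since a bound method is always truthy.

```python
class Variable:
    """Hypothetical stand-in for a graph input; not the real Theano class."""

    # Plain class-level default. A method named `is_static_jax` would not work
    # in the check below, because a bound method is always truthy.
    is_static_jax = False

    def __init__(self, name):
        self.name = name


class Constant(Variable):
    """Hypothetical stand-in for a constant graph input."""


def find_static_argnums(inputs):
    # Same comprehension as in the diff, applied to the stand-in classes.
    return [
        n
        for n, i in enumerate(inputs)
        if isinstance(i, Constant) or i.is_static_jax
    ]


a = Variable("a")
b = Variable("b")
b.is_static_jax = True  # marked by hand, as in the tests
c = Constant("c")

static_argnums = find_static_argnums([a, b, c])
print(static_argnums)  # [1, 2]
```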

For now, users will need to mark these variables by hand. For example, we can do the following to make the tests pass:

diff --git a/tests/sandbox/test_jax.py b/tests/sandbox/test_jax.py
index 89c46ff9b..c3c3d7225 100644
--- a/tests/sandbox/test_jax.py
+++ b/tests/sandbox/test_jax.py
@@ -534,10 +534,10 @@ def test_jax_Reshape():
     compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])

-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_jax_Reshape_nonconcrete():
     a = tt.vector("a")
     b = tt.iscalar("b")
+    b.is_static_jax = True
     x = tt.basic.reshape(a, (b, b))
     x_fg = theano.gof.FunctionGraph([a, b], [x])
     compare_jax_and_py(
@@ -666,10 +666,10 @@ def test_tensor_basics():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_arange_nonconcrete():

     a = tt.scalar("a")
+    a.is_static_jax = True
     a.tag.test_value = 10

     out = tt.arange(a)
@@ -677,7 +677,6 @@ def test_arange_nonconcrete():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

brandonwillard commented 3 years ago

I thought static parameters were effectively Constants; is that not the correct indicator already?

Regarding shared values, those don't really exist within the JAX compilation context; they're Constants at that point. See #73 for an explanation.

Also, if I recall, static_argnums only applies to the inputs of the JITed function, which implies that such an indicator would only have relevance during the JAX transpilation process and nowhere else.

junpenglao commented 3 years ago

> I thought static parameters were effectively Constants; is that not the correct indicator already?

I guess those parameters are not actually static (thus they are not Constants), but due to the limitations of XLA they nonetheless need to be treated as static at runtime. If we can mark them as static, it gives us a bit of additional flexibility to jit those functions.

> Also, if I recall, static_argnums only applies to the inputs of the JITed function, which implies that such an indicator would only have relevance during the JAX transpilation process and nowhere else.

Yes that's right.

brandonwillard commented 3 years ago

At a high level, we shouldn't add properties to classes unless they're directly relevant to the concepts/objects being modeled by the classes (e.g. a TensorVariable class should only concern itself with properties relating to tensors and/or variables). This particular addition is far too specific to the JAX transpilation process, and, just like the numerous C-related methods attached to our classes, we can incorporate this information differently.

From a lower level, our class implementations need to remain as simple and "static" as possible. Doing so greatly improves the comprehensibility of our code, since it introduces fewer runtime and downstream logic surprises (e.g. avoiding questions like "What's this field, where did it come from, and how did it get set to this?"). Also, we could leverage some individually small—but cumulatively large—performance advantages from this situation (e.g. __slots__), especially as graphs scale. See issue #72 for related concerns.

That said, the tag field is better for this situation; however, it's still not clear to me how we would use this information. Can you give a small example of how/when we could use it to avoid the limitations in jnp.reshape and others?
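
The contrast between a class-level attribute and the tag field can be sketched as follows. This is only an illustration under stated assumptions: `SlottedVariable` and `is_static_for_jax` are hypothetical names, and `types.SimpleNamespace` stands in for whatever object the real `tag` field holds.

```python
from types import SimpleNamespace


class SlottedVariable:
    """Hypothetical stand-in for a variable class that uses `__slots__`."""

    __slots__ = ("name", "tag")

    def __init__(self, name):
        self.name = name
        # Free-form namespace for auxiliary, backend-specific annotations.
        self.tag = SimpleNamespace()


def is_static_for_jax(var):
    # Untagged variables default to "not static".
    return getattr(var.tag, "is_static_jax", False)


a = SlottedVariable("a")
b = SlottedVariable("b")

# With `__slots__`, ad hoc attributes on the variable itself are rejected.
try:
    b.is_static_jax = True
    direct_assignment_ok = True
except AttributeError:
    direct_assignment_ok = False

# The tag accepts the same information without changing the class.
b.tag.is_static_jax = True

print(direct_assignment_ok, is_static_for_jax(a), is_static_for_jax(b))
```

Keeping the flag in the tag means only the JAX transpilation code needs to know about it, and the class definition stays unchanged.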

junpenglao commented 3 years ago

> That said, the tag field is better for this situation; however, it's still not clear to me how we would use this information. Can you give a small example of how/when we could use it to avoid the limitations in jnp.reshape and others?

Adding it to tag is a nice compromise. We could use it to avoid the limitation in jnp.reshape by marking the shape argument tensors as static for JAX; then theano.function(..., mode='jax') would also work (i.e., it would not raise an error during jax.jit).

ricardoV94 commented 2 years ago

I came across something like this in #631. There is a second problem: scalar symbolic variables become scalar NumPy arrays during execution, and these cannot be used as static arguments for JAX functions because they are not hashable.
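
The hashability problem can be seen with NumPy alone. The snippet below is a small sketch; converting with `.item()` is shown only as one possible workaround and is an assumption of this example, not necessarily what #631 does.

```python
import numpy as np

scalar_array = np.array(3)  # 0-d array, like a scalar input at run time

# NumPy arrays (including 0-d ones) do not support hashing.
try:
    hash(scalar_array)
    array_hashable = True
except TypeError:
    array_hashable = False

# Converting to a built-in Python scalar gives a hashable value.
python_scalar = scalar_array.item()

print(array_hashable, type(python_scalar).__name__, hash(python_scalar) == hash(3))
```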