pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.
https://pytensor.readthedocs.io

No JAX dispatch for `mul_without_zeros` #526

Open jessegrabowski opened 7 months ago

jessegrabowski commented 7 months ago

Describe the issue:

The ProdWithoutZeros Op arises in the gradient of pt.prod. This gradient currently cannot be compiled in JAX mode unless we specifically pass no_zeros_in_input=True. I guess we would just need a JAX dispatch for this Op? Or maybe a mapping to the correct jax.lax function?

Reproducible code example:

```python
import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
z = pt.prod(x, no_zeros_in_input=False)
gz = pytensor.grad(z, x)

f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])
```

Error message:

```shell
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    195 for thunk, node, old_storage in zip(
    196     thunks, order, post_thunk_old_storage
    197 ):
--> 198     thunk()
    199 for old_s in old_storage:

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

    [... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     54 if to_reduce:
     55     # In this case, we need to use the `jax.lax` function (if there
     56     # is one), and not the `jnp` version.
---> 57     jax_op = getattr(jax.lax, scalar_fn_name)
     58     init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     52     return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[61], line 1
----> 1 f_z([1, 2, 3, 4])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969     outputs = (
--> 970         self.vm()
    971         if output_subset is None
    972         else self.vm(output_subset=output_subset)
    973     )
    974 except Exception:
    975     restore_defaults()

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    200     old_s[0] = None
    201 except Exception:
--> 202     raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:531, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    526     warnings.warn(
    527         f"{exc_type} error does not allow us to add an extra error message"
    528     )
    529 # Some exception need extra parameter in inputs. So forget the
    530 # extra long error message in that case.
--> 531 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    194 try:
    195     for thunk, node, old_storage in zip(
    196         thunks, order, post_thunk_old_storage
    197     ):
--> 198         thunk()
    199     for old_s in old_storage:
    200         old_s[0] = None

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655     fgraph=self.fgraph,
    656     fgraph_jit=fgraph_jit,
    657     thunk_inputs=thunk_inputs,
    658     thunk_outputs=thunk_outputs,
    659 ):
--> 660     outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    663     compute_map[o_var][0] = True

    [... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     11 tensor_variable_4 = elemwise_fn_2(tensor_variable_3, x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
     15 tensor_variable_6 = dimshuffle_1(tensor_variable_5)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     52 to_reduce = sorted(axis, reverse=True)
     54 if to_reduce:
     55     # In this case, we need to use the `jax.lax` function (if there
     56     # is one), and not the `jnp` version.
---> 57     jax_op = getattr(jax.lax, scalar_fn_name)
     58     init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
     59     return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51     warnings.warn(message, DeprecationWarning, stacklevel=2)
     52     return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
Apply node that caused the error: Switch(Eq.0, True_div.0, Switch.0)
Toposort index: 13
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([1., 2., 3., 4.])]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3548, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_27218/3109327815.py", line 5, in <module>
    gz = pytensor.grad(z, x)
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 607, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1362, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1192, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
```

PyTensor version information:

Pytensor 2.17.4

Context for the issue:

I want the gradient of a product in JAX mode

ricardoV94 commented 7 months ago

I took out the bug label since it's more like a missing feature.

Is ProdWithoutZeros just pt.prod(x[pt.neq(x, 0)])? In that case I would suggest dispatching to something like that in JAX. jax.numpy.prod even accepts a where argument already: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.prod.html

Although I am not sure it's going to like that dynamic boolean array. Maybe it needs to be implemented as a Scan?
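
For reference, a minimal sketch of what such a dispatch could look like using jnp.prod's where argument. The module paths and registration pattern below are assumptions modelled on PyTensor's existing JAX dispatch machinery (and dtype/acc_dtype handling is ignored), so treat it as an illustration rather than a ready PR. Note that where masks elements inside the reduction, so it avoids the dynamic-shape boolean indexing of x[pt.neq(x, 0)]:

```python
import jax.numpy as jnp

# Assumed import locations -- check where `jax_funcify` and `ProdWithoutZeros`
# actually live in your PyTensor version before using this.
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import ProdWithoutZeros


@jax_funcify.register(ProdWithoutZeros)
def jax_funcify_ProdWithoutZeros(op, **kwargs):
    axis = op.axis  # CAReduce ops store the reduction axes (or None for all)

    def prod_without_zeros(x):
        # Product of the non-zero entries only: `where` masks zeros out of the
        # reduction, and the multiplicative identity (1) is used in their place.
        return jnp.prod(x, axis=axis, where=jnp.not_equal(x, 0))

    return prod_without_zeros
```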

ricardoV94 commented 7 months ago

The graph JAX produces for grad of Prod is absurd? They just enumerate all the cases where a zero might be (with some bisect logic)?

```python
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.prod(x)

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
```

jessegrabowski commented 7 months ago

Maybe rewrite to something like this?

```python
import jax
import jax.numpy as jnp

def prod(x):
    return jnp.exp(jnp.sum(jnp.log(x)))

# @jax.jit
def foo(x):
    return jax.grad(prod)(x)

jax.make_jaxpr(foo)(jnp.arange(800, dtype="float32"))
```

Amusingly, this was suggested as a workaround here before the actual thing we have now was implemented.

ricardoV94 commented 6 months ago

That would fail with negative inputs, no?
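
A quick illustration of the failure mode (values chosen here just for the example):

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 2.0, 3.0])
print(jnp.prod(x))                   # -6.0, the true product
print(jnp.exp(jnp.sum(jnp.log(x))))  # nan, since log of a negative number is nan
```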

ricardoV94 commented 6 months ago

Would prod(switch(eq(x, 0), 1, x)) work?

jessegrabowski commented 6 months ago

Isn't this exactly the jax graph but without the bisect logic to avoid checking every single value in x elemwise?

ricardoV94 commented 6 months ago

I don't see anywhere where JAX is checking for zeros.

ricardoV94 commented 6 months ago

The PR that implemented the grad is here: https://github.com/google/jax/pull/675/files

It's a bit odd; is that really better than just a switch statement?
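
For comparison, a rough JAX sketch of what a switch-based gradient of prod could look like (purely illustrative, not PyTensor's or JAX's actual implementation; it mirrors the case analysis that the Switch/ProdWithoutZeros graph encodes):

```python
import jax.numpy as jnp

def prod_grad_switch(x):
    # Gradient of prod(x) w.r.t. each x_i, written with elementwise switches
    # instead of JAX's bisection-based construction.
    zero_mask = jnp.equal(x, 0)
    n_zeros = jnp.sum(zero_mask)
    prod_all = jnp.prod(x)
    prod_nonzero = jnp.prod(jnp.where(zero_mask, 1.0, x))  # ProdWithoutZeros

    # No zeros: grad_i = prod(x) / x_i.
    grad_no_zeros = prod_all / x
    # Exactly one zero: the zero entry gets the product of the remaining
    # (non-zero) entries; every other entry's gradient is 0.
    grad_one_zero = jnp.where(zero_mask, prod_nonzero, 0.0)
    # Two or more zeros: every gradient is 0.
    return jnp.where(n_zeros == 0, grad_no_zeros,
                     jnp.where(n_zeros == 1, grad_one_zero, 0.0))
```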

ricardoV94 commented 6 months ago

In any case you are asking to rewrite a gradient, and unfortunately, because grads are eager, you don't know whether this ProdWithoutZeros comes from the grad of a prod or from some other source. So to be safe you would need to implement a dispatch for that exact Op.