`pytensor.function()` returns a `Function` object with a complicated `__call__` method. That method places inputs and allocates outputs in list-like storage objects that are heavily tuned to the C backend. This means that a function compiled to JAX or Numba will in general not work inside a longer JAX / Numba workflow (e.g., calling `vmap` or `grad` on the compiled function).
One obvious obstacle to providing a simpler, composable function concerns the handling of shared variables and updates. Shared variables are global variables that are passed as inputs to the actual inner function but are not provided explicitly by the user. Updates replace the original value of a shared variable with a (user-hidden) output of the function every time it is called.
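To make the semantics of shared variables and updates concrete, here is a small pure-Python model. No PyTensor is involved: `make_stateful` and the dictionary store are hypothetical illustrations, not PyTensor API. The model assumes, as I understand PyTensor's behavior, that all updates are evaluated on the values the shared variables had at the start of the call and are written back only afterwards.

```python
from typing import Callable, Dict, List

Vector = List[float]
# Hypothetical stand-in for shared-variable storage: name -> current value
Store = Dict[str, Vector]
# An output or update "expression", evaluated on the explicit input and the store
Expr = Callable[[Vector, Store], Vector]


def make_stateful(output: Expr, updates: Dict[str, Expr], store: Store) -> Callable[[Vector], Vector]:
    """Mimic a compiled function whose shared inputs and updates are hidden from the caller."""

    def fn(x: Vector) -> Vector:
        # Shared values are read implicitly: the caller only passes x
        out = output(x, store)
        # Every update is evaluated on the values from the start of the call...
        new_values = {name: expr(x, store) for name, expr in updates.items()}
        # ...and only then written back, replacing the old values
        store.update(new_values)
        return out

    return fn


store: Store = {"y": [1.0] * 5}
fn = make_stateful(
    output=lambda x, s: [xi + yi for xi, yi in zip(x, s["y"])],
    updates={"y": lambda x, s: [yi + 1.0 for yi in s["y"]]},
    store=store,
)
first = fn([0.0] * 5)   # computed with y == 1
second = fn([0.0] * 5)  # computed with y == 2
```

Because every update is evaluated before any of them is written back, updates that map `a` to `b` and `b` to `a` swap the two values in this model instead of copying one into the other.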
A simple JAX/Numba PyTensor function with a shared variable and an update looks like this:
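```python
import numpy as np
import pytensor
import pytensor.tensor as pt

shared_y = pytensor.shared(np.ones((5,)))
x = pt.vector("x")
# shared_y is an implicit input; its update is applied on every call
fn = pytensor.function([x], x + shared_y, updates={shared_y: shared_y + 1}, mode="JAX")
```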
This roughly translates to the following pseudocode:
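```python
import jax
import numpy as np

# Storage of the shared variable, living outside of the jitted code
shared_y = np.ones((5,))

@jax.jit
def inner_fn(x, y):
    # Pure function: returns the output and the updated value of y
    return x + y, y + 1

def fn(x):
    out, update_y = inner_fn(x, shared_y)
    # The update is written in place, outside of the jitted function
    shared_y[:] = update_y
    return out
```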
As far as I know, neither JAX nor Numba supports stateful jitted functions, so users would need to work with `inner_fn` directly.
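The problem with hidden state is not specific to JAX or Numba. As a rough, pure-Python illustration: any transformation that evaluates a function more than once, such as the central finite-difference derivative below (a deliberately simple stand-in, not how `jax.grad` works), sees a different `y` on each evaluation when the function updates hidden state. The pure `inner_fn`, with `y` passed explicitly, behaves as expected. All names here are hypothetical.

```python
from typing import Callable, Dict, Tuple


def central_difference(f: Callable[[float], float], x: float, h: float = 1e-3) -> float:
    """Approximate df/dx by evaluating f twice, at x + h and at x - h."""
    return (f(x + h) - f(x - h)) / (2 * h)


# Hidden global state, playing the role of the shared variable y
state: Dict[str, float] = {"y": 1.0}


def stateful_fn(x: float) -> float:
    # Reads y implicitly and applies the update y <- y + 1 on every call
    out = x + state["y"]
    state["y"] += 1.0
    return out


def inner_fn(x: float, y: float) -> Tuple[float, float]:
    # Pure version: y is an explicit input and its update an explicit output
    return x + y, y + 1.0


# d(x + y)/dx == 1, but the two evaluations of stateful_fn see y == 1 and y == 2
bad = central_difference(stateful_fn, 0.0)  # about -499.0
good = central_difference(lambda x: inner_fn(x, 1.0)[0], 0.0)  # about 1.0
```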
We could provide simpler `jax_function` and `numba_function` helpers that return a plain function which can be used inside a longer JAX / Numba workflow. In PyMC we implemented something like that for JAX: https://github.com/pymc-devs/pymc/blob/31c30dc1beea26e4bff52a93037540923feaaa84/pymc/sampling/jax.py#L108-L132
For Numba, the FAQ notes that global variables are treated as compile-time constants: https://numba.pydata.org/numba-doc/dev/user/faq.html#numba-doesn-t-seem-to-care-when-i-modify-a-global-variable
The proposal here is to give users easy access to the compiled (jitted or not) `inner_fn`.
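A minimal sketch of how a user could work with such an `inner_fn`, assuming it takes the shared values as extra explicit inputs and returns their updated values as extra outputs. The signature and the helper `run_steps` are hypothetical, and the sketch is written in plain Python so it runs without JAX.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def inner_fn(x: Vector, y: Vector) -> Tuple[Vector, Vector]:
    """Pure counterpart of fn: the shared value comes in, its update goes out."""
    out = [xi + yi for xi, yi in zip(x, y)]
    new_y = [yi + 1.0 for yi in y]
    return out, new_y


def run_steps(
    step: Callable[[Vector, Vector], Tuple[Vector, Vector]],
    x: Vector,
    y0: Vector,
    n_steps: int,
) -> Tuple[List[Vector], Vector]:
    """Call step n_steps times, threading the state through explicitly."""
    y = y0
    outs: List[Vector] = []
    for _ in range(n_steps):
        out, y = step(x, y)
        outs.append(out)
    return outs, y


outs, y_final = run_steps(inner_fn, [0.0] * 5, [1.0] * 5, n_steps=3)
```

This reproduces the sequence of outputs the stateful `fn` would return, but the caller now owns the state and can call `inner_fn` at a fixed `y` as often as a transformation needs. In JAX itself the loop would typically be a `jax.lax.scan`, whose step function returns `(carry, output)`, i.e. the reverse of the `(out, new_y)` order used here.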