pymc-devs / pytensor

PyTensor allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays.

Provide lower level Numba and Jax functions #222

Open ricardoV94 opened 1 year ago

ricardoV94 commented 1 year ago


pytensor.function() returns a Function object with a complicated __call__ method that places inputs in, and allocates outputs into, list-like storage containers that are very much tuned to the C backend. As a result, a function compiled to JAX or Numba generally cannot be used inside a larger JAX or Numba workflow (e.g., calling vmap or grad on the compiled function).

We could provide simpler jax_function and numba_function helpers that return only the backend-compiled function, without this wrapper. In PyMC we implemented something like this for JAX.

There is one obvious limitation, which concerns the handling of shared variables and updates. Shared variables are global variables that are passed as inputs to the actual inner function but are not provided explicitly by the user. Updates replace the stored value of a shared variable with a (user-hidden) output of the function every time it is called.

A simple JAX/Numba PyTensor function with global variables and updates looks like this:

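```python
import numpy as np
import pytensor
import pytensor.tensor as pt

shared_y = pytensor.shared(np.ones((5,)))
x = pt.vector("x")
# shared_y is an implicit input; the update overwrites it after every call
fn = pytensor.function([x], x + shared_y, updates={shared_y: shared_y + 1}, mode="JAX")
```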

This roughly translates to the following pseudocode:

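```python
import numpy as np

shared_y = np.ones((5,))  # global state that the user never passes in


def fn(x):
    global shared_y

    def inner_fn(x, y):
        # Pure computation: returns the output and the updated value of y
        return x + y, y + 1

    out, update_y = inner_fn(x, shared_y)
    shared_y[:] = update_y  # hidden in-place update of the global state
    return out
```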

I don't think either JAX or Numba supports stateful jitted functions, so users would need to work with inner_fn directly.

The proposal here is to give users easy access to the compiled (jitted or not) inner_fn.
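With direct access to inner_fn, the caller would thread the shared state explicitly instead of relying on a hidden global. A minimal sketch in plain NumPy, assuming inner_fn keeps the (x, y) -> (out, new_y) signature from the pseudocode; run_steps is a hypothetical helper for illustration, not part of PyTensor:

```python
import numpy as np


def inner_fn(x, y):
    """Pure counterpart of fn: the shared value y is an explicit input and
    its update is returned as an extra output, so nothing global is mutated."""
    return x + y, y + 1


def run_steps(x, y, n_steps):
    """Hypothetical helper: call inner_fn repeatedly, carrying the state
    forward by hand, as the stateful fn would do behind the scenes."""
    outs = []
    for _ in range(n_steps):
        out, y = inner_fn(x, y)  # the caller owns the state between calls
        outs.append(out)
    return outs, y


outs, y_final = run_steps(np.zeros(5), np.ones(5), n_steps=3)
# outs holds arrays of 1s, 2s and 3s; y_final is an array of 4s
```

Because every call is a pure function of its arguments, this is in principle the shape of code that JAX transformations such as vmap and grad, or a carry-based jax.lax.scan loop, can work with, which is not the case when the state lives in a hidden global.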

ammar-s847 commented 1 year ago

Hey, is this still available to contribute to? Would love to get started!

twiecki commented 1 year ago
