brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Standardizing and generalizing object-oriented transformations #628

Closed: chaoming0625 closed this pull request 7 months ago

chaoming0625 commented 7 months ago

This PR standardizes how custom object-oriented (OO) transformations are built. The key components are brainpy.math.VariableStack and brainpy.math.eval_shape.

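Every OO transformation involves two steps. First, brainpy.math.eval_shape is called inside a brainpy.math.VariableStack context to find all Variables used by the target function; only abstract shapes are evaluated, so no real computation is performed. Second comes the actual compilation phase, which compiles the model for the target device, passing the Variables' values in explicitly and writing the updated values back afterwards.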

For example, a custom object-oriented JIT compilation interface can be written as:


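import jax
import brainpy.math as bm


def jit(fun):
  # Both names are reassigned inside `wrapper`, so they are declared `nonlocal` there.
  stack = None    # becomes a bm.VariableStack on the first call
  jit_fun = None  # compiled version of `new_fun`

  def new_fun(var_values, *args, **kwargs):
    # Load the Variable values that were passed in explicitly ...
    for k, v in var_values.items():
      stack[k].value = v
    ret = fun(*args, **kwargs)
    # ... and return the updated values explicitly, keeping `new_fun` pure.
    new_vars = stack.dict_data()
    return ret, new_vars

  def wrapper(*args, **kwargs):
    nonlocal stack, jit_fun

    # [first step]: find all the Variables used by `fun`
    if stack is None:
      with bm.VariableStack() as stack:
        ret = bm.eval_shape(fun, *args, **kwargs)
        jit_fun = jax.jit(new_fun)
      # If this stack is not the first (outermost) one, return the
      # shape-evaluation result directly instead of compiling.
      if not stack.is_first_stack():
        return ret

    # [second step]: jit compilation, then write the new values back
    ret, new_vars = jit_fun(stack.dict_data(), *args, **kwargs)
    stack.assign(new_vars)
    return ret

  return wrapper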