This PR standardizes the customization of object-oriented (OO) transformations. The key is using `brainpy.math.VariableStack` and `brainpy.math.eval_shape`.

Each OO transformation involves two steps. The first step uses `brainpy.math.eval_shape` to abstractly evaluate the target function, without running the real computation, while a `brainpy.math.VariableStack` collects every `Variable` the function uses. The second step is the actual compilation phase, which compiles the model for the given target device.

For example, a custom object-oriented JIT compilation interface can be written as:
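```python
import jax
import brainpy.math as bm


def jit(fun):
    stack: bm.VariableStack = None  # Variables discovered in the first step
    jit_fun = None                  # compiled version of `new_fun`

    def new_fun(var_values, *args, **kwargs):
        # Load the (traced) Variable values, run the original function,
        # and return the updated Variable values as explicit outputs.
        for k, v in var_values.items():
            stack[k].value = v
        ret = fun(*args, **kwargs)
        new_vars = stack.dict_data()
        return ret, new_vars

    def wrapper(*args, **kwargs):
        # `stack` and `jit_fun` live in the enclosing `jit`, so use `nonlocal`, not `global`.
        nonlocal stack, jit_fun
        # [first step]: find all the variables
        if stack is None:
            with bm.VariableStack() as stack:
                ret = bm.eval_shape(fun, *args, **kwargs)
            jit_fun = jax.jit(new_fun)
            # Not the outermost VariableStack (an enclosing transformation is
            # still collecting Variables): return the abstract result.
            if not stack.is_first_stack():
                return ret
        # [second step]: jit compilation
        ret, new_vars = jit_fun(stack.dict_data(), *args, **kwargs)
        stack.assign(new_vars)
        return ret

    return wrapper
```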
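To see the two steps in isolation, here is a minimal pure-JAX sketch of the same pattern. `Var`, `ToyStack`, and `toy_jit` are hypothetical stand-ins for `bm.Variable`, `bm.VariableStack`, and the `jit` above, not BrainPy APIs; `jax.eval_shape` plays the role of `bm.eval_shape`, and the nesting check (`is_first_stack`) is omitted for brevity.

```python
import jax
import jax.numpy as jnp


class Var:
    """Hypothetical stand-in for a BrainPy Variable: a mutable box holding an array."""
    active_stack = None  # the ToyStack currently collecting Vars, if any

    def __init__(self, value):
        self._value = jnp.asarray(value)

    def _register(self):
        if Var.active_stack is not None:
            Var.active_stack.add(self)

    @property
    def value(self):
        self._register()
        return self._value

    @value.setter
    def value(self, new_value):
        self._register()
        self._value = new_value


class ToyStack:
    """Hypothetical stand-in for bm.VariableStack: records every Var touched while active."""

    def __init__(self):
        self.vars = {}       # id(var) -> var
        self.originals = {}  # id(var) -> value before the abstract run

    def add(self, var):
        if id(var) not in self.vars:
            self.vars[id(var)] = var
            self.originals[id(var)] = var._value

    def __enter__(self):
        Var.active_stack = self
        return self

    def __exit__(self, *exc_info):
        Var.active_stack = None
        # Undo writes made during the abstract run so no tracers leak out.
        for key, var in self.vars.items():
            var._value = self.originals[key]

    def dict_data(self):
        return {key: var._value for key, var in self.vars.items()}

    def assign(self, data):
        for key, value in data.items():
            self.vars[key]._value = value


def toy_jit(fun):
    stack = None
    jit_fun = None

    def pure_fun(var_values, *args):
        # State goes in and comes out explicitly, so jax.jit sees a pure function.
        stack.assign(var_values)
        ret = fun(*args)
        return ret, stack.dict_data()

    def wrapper(*args):
        nonlocal stack, jit_fun
        if stack is None:
            # Step 1: trace abstractly (no real computation) to discover the Vars.
            with ToyStack() as stack:
                jax.eval_shape(fun, *args)
            jit_fun = jax.jit(pure_fun)
        # Step 2: run the compiled pure function and write the new state back.
        ret, new_values = jit_fun(stack.dict_data(), *args)
        stack.assign(new_values)
        return ret

    return wrapper


counter = Var(0.0)


def step(x):
    counter.value = counter.value + x
    return counter.value * 2


fast_step = toy_jit(step)
out1 = fast_step(1.0)  # counter becomes 1.0
out2 = fast_step(2.0)  # counter becomes 3.0
```

Because the write inside `step` during the abstract run is undone on exit from the stack, the counter only advances when the compiled function actually executes.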