Open CloudyDory opened 5 months ago
Thanks for opening this great question. Actually, the object-oriented style in BrainPy does not support a general mapping transformation with `vmap` and `pmap`. But we can easily customize a mapping for a specific problem. Here I will give you an example.
The key to BrainPy's `Variable` system is to find all variables used in the objects and then transform the object into a function so that it can be compiled by JAX's functional transformations. Existing BrainPy transformations such as `brainpy.math.jit` and `brainpy.math.scan` already hide these steps. However, for a new transformation, users can follow the same two steps themselves.
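The two steps can be illustrated without BrainPy at all. The following is a minimal sketch in plain Python, under loud assumptions: `Variable`, `Counter`, `collect_variables`, and `functionalize` are hypothetical names invented here, not BrainPy APIs, and BrainPy's real variable discovery is more involved.

```python
class Variable:
    """Hypothetical mutable container standing in for a framework variable."""
    def __init__(self, value):
        self.value = value


class Counter:
    """Toy stateful object: each call adds the input to a running total."""
    def __init__(self):
        self.total = Variable(0.0)

    def __call__(self, x):
        self.total.value = self.total.value + x
        return 2.0 * self.total.value


def collect_variables(obj):
    # Step 1: find every Variable attached to the object.
    return {name: attr for name, attr in vars(obj).items()
            if isinstance(attr, Variable)}


def functionalize(obj):
    # Step 2: wrap the object as a pure function of (state values, input).
    variables = collect_variables(obj)

    def pure_fn(state, x):
        saved = {k: v.value for k, v in variables.items()}
        try:
            for k, v in variables.items():
                v.value = state[k]      # load the caller-supplied state
            out = obj(x)                # run the stateful code
            new_state = {k: v.value for k, v in variables.items()}
        finally:
            for k, v in variables.items():
                v.value = saved[k]      # restore the object's own state
        return out, new_state

    return pure_fn


counter = Counter()
step = functionalize(counter)
out, state = step({"total": 1.0}, 3.0)
```

Because `pure_fn` takes all state as an argument and returns the updated state, it has the shape that functional transformations expect, and the original object is left unchanged after the call.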
In your case, you want to `vmap` the gradient function to get the batched gradients. So, the `weights` should not be batched, all `states` or `variables` should be batched, and the outputs should also be batched. Therefore, we can customize this transformation as:
import jax
import brainpy.math as bm
from functools import wraps


def vmap_grad_fun(f, *inputs):
    # Step 1: find all variables #
    # -------------------------- #
    # evaluate shapes only, without spending any actual FLOPs
    vars, _ = bm.eval_shape(f, *inputs)
    # separate variables into two groups: weights and states
    weights, states = vars.separate_by_instance(bm.TrainVar)

    # Step 2: transform the object into a function compatible with jax.vmap #
    # ---------------------------------------------------------------------- #
    @wraps(f)
    def new_fun(ws, vars, inputs):
        # A. assign the weights and states of each batch element to the model
        for key in ws:
            weights[key] = ws[key]
        for key in vars:
            states[key] = vars[key]
        # B. run the function
        outputs = f(*inputs)
        # C. return the outputs of each batch element
        return outputs

    ori_weights, ori_states = weights.dict_data(), vars.dict_data()
    # replicate the states along a new leading batch axis
    batch_size = inputs[0].shape[0]
    batched_states = jax.tree_util.tree_map(
        lambda x: bm.repeat(bm.expand_dims(x, 0), batch_size, axis=0), ori_states
    )
    # batch the states and inputs; the weights are shared (in_axes=None)
    batched_outs = jax.vmap(new_fun, in_axes=(None, 0, 0), out_axes=0)(
        ori_weights, batched_states, inputs
    )
    del batched_states
    # recover the original weights and states
    for key in ori_weights:
        weights[key] = ori_weights[key]
    for key in ori_states:
        vars[key] = ori_states[key]

    # Step 3: return the batched outputs
    return batched_outs
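
The batching pattern used above (shared weights, replicated states, batched inputs) can also be tried in isolation with plain JAX. This is only a minimal sketch under assumptions: the toy `loss`, the state name `V`, and the shapes are invented for illustration, and no BrainPy objects are involved.

```python
import jax
import jax.numpy as jnp


def loss(w, state, x):
    # Toy single-example loss: w is shared; state and x belong to one example.
    return jnp.sum((w * x + state["V"]) ** 2)


# Gradient with respect to the weights only (argument 0).
grad_fn = jax.grad(loss, argnums=0)

w = jnp.array([1.0, 2.0])               # shared weights, shape (2,)
xs = jnp.arange(6.0).reshape(3, 2)      # batch of 3 inputs, shape (3, 2)
state = {"V": jnp.zeros(2)}             # state for a single example

# Replicate every state along a new leading batch axis.
batch_size = xs.shape[0]
batched_state = jax.tree_util.tree_map(
    lambda s: jnp.repeat(jnp.expand_dims(s, 0), batch_size, axis=0), state
)

# None: do not map over w; 0: map over the leading axis of states and inputs.
batched_grads = jax.vmap(grad_fn, in_axes=(None, 0, 0), out_axes=0)(
    w, batched_state, xs
)
print(batched_grads.shape)  # (3, 2): one weight gradient per example
```

Passing `None` for the weights keeps a single copy shared across the batch, while each example gets its own slice of the replicated state, which mirrors the `in_axes=(None, 0, 0)` choice in the BrainPy version.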
I hope this example can help you achieve the desired transformation.
For a model written to process single input data, is it possible to convert the model to process batched input data simply by using `jax.vmap`? Or do we have to rewrite the model to process batched data? The code section looks like this:
It currently raises the following error:
I found a previous issue (#206) mentioning this. Is it still not possible to use `jax.vmap` with BrainPy models?