brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Convert a BrainPy model to process batched input by `jax.vmap` #608

Open CloudyDory opened 5 months ago

CloudyDory commented 5 months ago

For a model written to process a single input sample, is it possible to convert it to process batched input simply by using jax.vmap? Or do we have to rewrite the model to handle batched data?

The code section looks like this:

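import functools

import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

# `model`, `cfg` and `blank_img` are defined elsewhere in the original script.

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())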

def step_run(i, x_single):
    '''
    Inputs:
        x_single: [height, width]
    '''
    x = bm.where(bm.logical_and(cfg['stim_start_timepoint']<=i, i<cfg['stim_end_timepoint']), x_single, blank_img)
    out = model.step_run(i, x)  # [n_neuron]
    return out

def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    model.reset_state() 
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), indices)  # [length, n_neuron]
    frate_out = bm.sum(spike_out, axis=0) + 1.0e-6  # [n_neuron]

    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_neuron]
    loss = bp.losses.nll_loss(-predicts, y_single)  # scalar; the negative sign is added manually because BrainPy's nll_loss does not apply it
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

grad_f = jax.vmap(bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True))

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch, 1]
    '''
    train_vars = model.train_vars().unique()

    grads, losses, acces = grad_f(x_batch, y_batch)  # PyTree of gradients, [batch], [batch]
    grads_mean = jax.tree_util.tree_map(lambda x: bm.mean(x, axis=0), grads)  # average the per-sample gradients

    loss = losses.mean()  # scalar
    acc = acces.mean()    # scalar
    opt.update(grads_mean)

    return loss, acc

It currently raises the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[50000] wrapped in a BatchTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
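The error can be reproduced in plain JAX, without BrainPy. In the minimal sketch below (`state`, `impure_step` and `pure_step` are hypothetical names), `impure_step` writes an intermediate value into a global dict while it is being vmapped, so a `BatchTracer` stays behind after `jax.vmap` returns, much like a model's state Variables keeping traced values. `pure_step` instead receives the state as an argument and returns the new state, which is the pattern the error message asks for.

```python
import jax
import jax.numpy as jnp

# Hypothetical global state, standing in for a model's state Variables.
state = {"v": jnp.zeros(3)}


def impure_step(x):
    # Side effect: the BatchTracer created inside vmap is stored globally.
    state["v"] = state["v"] + x
    return state["v"].sum()


def pure_step(v, x):
    # Pure version: the state comes in as an argument and goes out as an output.
    new_v = v + x
    return new_v.sum(), new_v


xs = jnp.ones((4, 3))  # a batch of 4 samples

jax.vmap(impure_step)(xs)
leaked = state["v"]        # a tracer that escaped the vmap scope
state["v"] = jnp.zeros(3)  # discard it before running the pure version

sums, new_vs = jax.vmap(pure_step, in_axes=(None, 0))(state["v"], xs)
print(sums.shape, new_vs.shape)  # (4,) (4, 3)
```

Using `leaked` in a later computation is the kind of access that raises `UnexpectedTracerError`, whereas the pure version returns everything it produces.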

I found a previous issue (#206) mentioning this. Is it still not possible to use jax.vmap with BrainPy models?
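For reference, the loss and accuracy in `loss_fun` can be checked in isolation. The plain-JAX sketch below uses a hypothetical helper `rate_nll` and made-up spike counts: it sums spikes into firing rates, normalizes them into log-probabilities as `predicts` does, and takes the negative log-likelihood -log p(label). The minus sign must be applied exactly once, which is what the manual `-predicts` in the question is meant to ensure.

```python
import jax.numpy as jnp


def rate_nll(spike_out, label, eps=1.0e-6):
    """Hypothetical helper: NLL and accuracy from spike counts for one sample.

    spike_out: [length, n_neuron] spikes; label: integer class index.
    """
    frate = spike_out.sum(axis=0) + eps       # firing rates, [n_neuron]
    log_p = jnp.log(frate / frate.sum())      # log-probabilities, [n_neuron]
    loss = -log_p[label]                      # NLL = -log p(label), sign applied once
    acc = (log_p.argmax() == label).astype(jnp.float32)
    return loss, acc


spikes = jnp.array([[1.0, 0.0, 0.0],
                    [1.0, 1.0, 0.0]])  # [length=2, n_neuron=3]
loss, acc = rate_nll(spikes, 0)
```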

chaoming0625 commented 5 months ago

Thanks for raising this great question. The object-oriented style of BrainPy does not support general mapping transformations such as vmap and pmap. However, we can easily customize the mapping for a specific problem. Here I will give you an example.

chaoming0625 commented 5 months ago

The key idea of BrainPy's Variable system is to find all Variables used by an object and then transform the object into a function, so that it can be compiled by JAX's functional transformations. Existing BrainPy transformations such as brainpy.math.jit and brainpy.math.scan already hide these steps. For a new transformation, however, users can follow the same two steps themselves.
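The two steps can be illustrated without depending on BrainPy internals. In the minimal plain-JAX sketch below, `Var`, `TrainVar`, `Accumulator`, `collect_vars` and `functionalize` are hypothetical stand-ins, not BrainPy's actual API: step 1 collects the Variables an object uses and splits them into trainable weights and other states, and step 2 wraps a method into a pure function that receives their values as arguments, returns the updated states, and puts the original values back so nothing traced is left inside the object.

```python
import jax
import jax.numpy as jnp


class Var:
    """Hypothetical mutable container, standing in for a BrainPy Variable."""

    def __init__(self, value):
        self.value = jnp.asarray(value)


class TrainVar(Var):
    """Hypothetical trainable container, standing in for bm.TrainVar."""


class Accumulator:
    """Toy stateful object: scales its input and accumulates it into a state."""

    def __init__(self):
        self.w = TrainVar(2.0)
        self.acc = Var(jnp.zeros(3))

    def step(self, x):
        self.acc.value = self.acc.value + self.w.value * x
        return self.acc.value.sum()


def collect_vars(obj):
    # Step 1: find every Var attribute and split it into (weights, states).
    found = {k: v for k, v in vars(obj).items() if isinstance(v, Var)}
    weights = {k: v for k, v in found.items() if isinstance(v, TrainVar)}
    states = {k: v for k, v in found.items() if not isinstance(v, TrainVar)}
    return weights, states


def functionalize(obj, method):
    # Step 2: wrap `method` as a pure function of (weight values, state values, *args).
    weights, states = collect_vars(obj)
    all_vars = {**weights, **states}

    def pure_fn(w_vals, s_vals, *args):
        originals = {k: v.value for k, v in all_vars.items()}
        for k, val in w_vals.items():
            weights[k].value = val
        for k, val in s_vals.items():
            states[k].value = val
        out = method(*args)
        new_states = {k: v.value for k, v in states.items()}
        # Put the original values back so no tracer is left inside `obj`.
        for k, val in originals.items():
            all_vars[k].value = val
        return out, new_states

    return pure_fn, weights, states


model = Accumulator()
pure_step, weights, states = functionalize(model, model.step)
w_vals = {k: v.value for k, v in weights.items()}
s_vals = {k: v.value for k, v in states.items()}

batch = 4
# Weights are shared (None); states and inputs get a leading batch axis (0).
batched_s = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (batch,) + x.shape), s_vals)
xs = jnp.ones((batch, 3))
outs, new_states = jax.vmap(pure_step, in_axes=(None, 0, 0))(w_vals, batched_s, xs)
print(outs.shape, new_states["acc"].shape)  # (4,) (4, 3)
```

Because the weights are mapped with `None`, every sample sees the same parameters, while each sample works on its own copy of the state.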

In your case, you want to vmap the gradient function to obtain per-sample gradients. Therefore, the weights must not be batched (they are shared by all samples), while all other states (Variables) and the outputs should be batched. We can customize this transformation as follows:

import jax
import jax.numpy as jnp

import brainpy.math as bm
from functools import wraps

def vmap_grad_fun(f, *inputs):
  # Step 1: find all Variables used by `f` #
  # -------------------------------------- #

  # evaluate `f` without spending any actual FLOPs; `f` is written for a
  # single sample, so only the first sample of each input is used here
  single_inputs = jax.tree_util.tree_map(lambda x: x[0], inputs)
  all_vars, _ = bm.eval_shape(f, *single_inputs)

  # separate the variables into two groups: weights (TrainVar) and states (the rest)
  weights, states = all_vars.separate_by_instance(bm.TrainVar)

  # Step 2: transform the object into a function compatible with jax.vmap #
  # --------------------------------------------------------------------- #

  @wraps(f)
  def new_fun(ws, ss, inputs):
    # A. assign the shared weights and this sample's states to the model,
    #    updating each Variable in place through `.value`
    for key in ws: weights[key].value = ws[key]
    for key in ss: states[key].value = ss[key]

    # B. run the function
    outputs = f(*inputs)

    # C. return the outputs of this sample
    return outputs

  ori_weights, ori_states = weights.dict_data(), states.dict_data()
  # replicate the states along a new leading batch axis
  batch_size = inputs[0].shape[0]
  batched_states = jax.tree_util.tree_map(
    lambda x: jnp.repeat(jnp.expand_dims(x, 0), batch_size, axis=0), ori_states)

  # weights are shared (in_axes=None); states and inputs are batched (in_axes=0)
  batched_outs = jax.vmap(new_fun, in_axes=(None, 0, 0), out_axes=0)(ori_weights, batched_states, inputs)
  del batched_states

  # restore the original weights and states so that no BatchTracer stays in the model
  for key in ori_weights: weights[key].value = ori_weights[key]
  for key in ori_states: states[key].value = ori_states[key]

  # Step 3: return the batched outputs
  return batched_outs
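The `in_axes=(None, 0, 0)` choice above (shared weights, batched states and inputs) is the same one that per-sample gradients need in plain JAX. The sketch below is a minimal illustration, not BrainPy code: `loss_fn`, `params` and the toy data are hypothetical. The parameters are mapped with `None`, the data with `0`, and the per-sample gradients are then averaged over the batch axis with `jax.tree_util.tree_map`, as an optimizer step would expect.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # Hypothetical per-sample loss: squared error of a linear model on one sample.
    pred = jnp.dot(params["w"], x) + params["b"]
    loss = (pred - y) ** 2
    aux = jnp.abs(pred - y) < 0.5  # a per-sample "accuracy"-like metric
    return loss, aux


# Weights shared across the batch (in_axes=None); data batched (in_axes=0).
per_sample_grad = jax.vmap(
    jax.value_and_grad(loss_fn, has_aux=True), in_axes=(None, 0, 0))

params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # [batch=3, 2]
ys = jnp.array([1.0, 0.0, 0.0])                       # [batch=3]

(losses, accs), grads = per_sample_grad(params, xs, ys)
print(grads["w"].shape)  # (3, 2): one gradient per sample

# Average over the batch axis before the optimizer step.
grads_mean = jax.tree_util.tree_map(lambda g: g.mean(axis=0), grads)
loss, acc = losses.mean(), accs.mean()
```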

I hope this example can help you achieve the desired transformation.