brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Using custom iterable or functional input in DSTrainer #602

Closed CloudyDory closed 5 months ago

CloudyDory commented 5 months ago

When simulating a network, it is sometimes inconvenient to pre-generate the entire input data due to memory constraints. If we use DSRunner, we can supply an iterable or functional input when initializing the runner, so the input data can be stored in a compressed form or generated at runtime. However, this does not seem to be possible when using DSTrainer: the DSTrainer object requires the entire input data to be supplied when calling the predict() function. Is it possible to overcome this issue?

chaoming0625 commented 5 months ago

Actually, they all use brainpy.math.for_loop internally, so DSTrainer can also support generating data at runtime. But I recommend using brainpy.math.for_loop or brainpy.math.jit directly instead.

For example, if your data is organized as:


def your_data_at_each_time_step():
    return  # something: build or decompress the input for one time step

def step_run(i):
    inp = your_data_at_each_time_step()  # data generated at runtime
    return model.step_run(i, inp)

outs = bm.for_loop(step_run, np.arange(1000))
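
For concreteness, a minimal runnable sketch of this pattern is given below; the bp.dyn.Lif population and the random per-step input are illustrative stand-ins for your own model and data generator:

import numpy as np
import brainpy as bp
import brainpy.math as bm

model = bp.dyn.Lif(10)  # stand-in for your own network

def step_run(i):
    # Build the input on the fly, so a full [1000, 10] input array is never materialized.
    inp = bm.random.rand(10) * 5.0
    return model.step_run(i, inp)

outs = bm.for_loop(step_run, np.arange(1000))  # stacked outputs over 1000 steps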

However, if your data is generated or loaded from an external source, you can use jit:


@bm.jit
def step_run(i, inp):
    return model.step_run(i, inp)

for i in np.arange(10000):
    # load data
    inp = ...
    out = step_run(i, inp)

CloudyDory commented 5 months ago

Perhaps I need to provide a more complete picture of the problem I am facing here.

The training process is:

  1. Extract a batch of training images with dimension [batch, height, width] and the corresponding class labels with dimension [batch].
  2. For each image in the batch, extend it into a long image sequence with dimension [length, height, width], where each frame equals the training image for idx1 <= i < idx2 and zeros for 0 <= i < idx1 or idx2 <= i < length (a dense-array sketch of this step follows the list).
  3. Feed the image sequence to the model and count the output spikes over the duration length.
  4. Accumulate the loss and gradient for this single training image.
  5. Perform one step of gradient descent when all images in the batch have been processed.
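
For reference, a dense version of step 2 would look roughly like the sketch below (the names length, height, width, idx1, idx2 and img are illustrative); this full [length, height, width] array is exactly the memory cost that the compressed class shown later avoids:

import numpy as np

length, height, width = 10_000, 64, 64   # illustrative sizes
idx1, idx2 = 1_000, 9_000                # stimulus window
img = np.random.rand(height, width).astype(np.float32)

seq = np.zeros((length, height, width), dtype=np.float32)
seq[idx1:idx2] = img  # the training image repeated for idx1 <= i < idx2, zeros elsewhere
# seq stores length copies of a frame even though it contains only two distinct images.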

The problem is in step 2: length can be quite large, and the extended image sequence can occupy too much memory. Therefore, I have created a custom indexable class to store the sequence in a compressed form. It only stores the two unique images in the sequence and selects one as output given the query index.

from collections.abc import Sequence

# Sequence object to efficiently store the input image sequence
class CompressedImage(Sequence):
    def __init__(self, img1, img2, img2_idx_range, length):
        self.i = 0
        self.img1 = img1
        self.img2 = img2
        self.img2_idx_range = img2_idx_range
        self.length = length  # in number of time points
        self.shape = (length,) + img1.shape

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.length:
            if self.i < self.img2_idx_range[0] or self.i >= self.img2_idx_range[1]:
                img = self.img1
            else:
                img = self.img2
        else:
            raise StopIteration
        self.i += 1
        return img

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Support negative indexing
        if i < 0:
            i = len(self) + i

        if 0 <= i < self.length:
            if i < self.img2_idx_range[0] or i >= self.img2_idx_range[1]:
                img = self.img1
            else:
                img = self.img2
        else:
            raise IndexError("array index ({}) out of range [0, {})".format(i, len(self)))
        return img
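
As a quick sanity check, a hypothetical usage of this class (the frame shapes and index range below are illustrative):

import numpy as np

img_a = np.ones((2, 2), dtype=np.float32)
img_b = np.zeros((2, 2), dtype=np.float32)
seq = CompressedImage(img_a, img_b, (3, 7), length=10)

assert len(seq) == 10
assert seq[0] is img_a    # outside img2_idx_range -> img1
assert seq[3] is img_b    # inside [3, 7)          -> img2
assert seq[-1] is img_a   # negative indexing is also supported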

And my training section looks like this:

blank_img = bm.zeros([cfg['height'], cfg['width']], dtype=bm.float_)  # [height, width]

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, x_single):
    bp.share.save(i=i, t=i * bm.get_dt())
    return model.step_run(i, x_single)

# define the loss function
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    inputs_compressed = CompressedImage(x_single, blank_img, [cfg['stim_start_timepoint'],cfg['stim_end_timepoint']], cfg['total_timepoint'])

    model.reset_state()        
    spike_out = bm.for_loop(step_run, (indices, inputs_compressed))  # [n_out, length]
    frate_out = bm.sum(spike_out, axis=1) + 1.0e-6   # [n_out]
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_out]

    loss = bp.losses.nll_loss(-predicts, y_single)  # scalar; need to manually add a negative sign because BrainPy's nll_loss does not
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

def grad_fun(last_grad, input_target):
    x_single, y_single = input_target  # [height, width], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch]
    '''
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))
    loss = losses.mean()
    acc = acces.mean()
    opt.update(grads)
    return loss, acc

However, the code generates an error at the line spike_out = bm.for_loop(step_run, (indices, inputs_compressed)):

*** TypeError: Value <__main__.CompressedImage object at 0x7f10c82b5f10> with type <class '__main__.CompressedImage'> is not a valid JAX type

The problem could potentially be solved by only jitting the step_run() function, as in your second recommendation. However, it may be slower than using bm.for_loop, since it relies on an explicit Python for-loop (see the updates in issue #552). So I am still not sure how to pass a customized input generator into bm.for_loop().

chaoming0625 commented 5 months ago

This problem can be solved easily. The trick is to use bm.where.

blank_img = bm.zeros([cfg['height'], cfg['width']], dtype=bm.float_)  # [height, width]

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, x_single):
    x = bm.where(i < cfg['stim_start_timepoint'],
                 x_single,
                 blank_img)
    return model.step_run(i, x)

# define the loss function
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    indices = np.arange(cfg['total_timepoint'])  # sequence length

    model.reset_state()        
    spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), (indices,))  # [n_out, length]
    frate_out = bm.sum(spike_out, axis=1) + 1.0e-6   # [n_out]
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_out]

    loss = bp.losses.nll_loss(-predicts, y_single)  # scalar; need to manually add a negative sign because BrainPy's nll_loss does not
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

def grad_fun(last_grad, input_target):
    x_single, y_single = input_target  # [height, width], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch]
    '''
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))
    loss = losses.mean()
    acc = acces.mean()
    opt.update(grads)
    return loss, acc

The modifications lie in the step_run() function and in the line:

spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), (indices,)) 
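
Note that this sketch only switches on stim_start_timepoint; if the intended schedule is blank/image/blank as described in step 2 (the image shown only for stim_start_timepoint <= i < stim_end_timepoint), the condition can combine both bounds, for example:

def step_run(i, x_single):
    # Show the image only inside the stimulus window; blank frame otherwise.
    show = bm.logical_and(i >= cfg['stim_start_timepoint'],
                          i < cfg['stim_end_timepoint'])
    x = bm.where(show, x_single, blank_img)
    return model.step_run(i, x)
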
CloudyDory commented 5 months ago

Thank you very much! I knew about bm.where, but did not know about functools.partial before.

I now have a new error in the gradient accumulation line new_grad = jax.tree_map(bm.add, last_grad, grads):


  File ~/project/snn_model/model_train.py:651 in train
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:997 in scan
    rets = transform(init, operands)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:923 in call
    return jax.lax.scan(f=fun2scan,

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/traceback_util.py:179 in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:258 in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:244 in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py:67 in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py:61 in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/profiler.py:336 in wrapper
    return func(*args, **kwargs)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2288 in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2310 in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:191 in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:914 in fun2scan
    carry, results = body_fun(carry, x)

  File ~/project/snn_model/model_train.py:640 in grad_fun
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/tree_util.py:243 in tree_map
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]

  File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/tree_util.py:243 in <listcomp>
    all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]

ValueError: Custom node type mismatch: expected type: <class 'brainpy._src.math.object_transform.variables.TrainVar'>, value: Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>.

The contents of last_grad and grads are:

IPdb [8]: last_grad
Out  [8]: {'Linear27.weight': TrainVar(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha55.weight': TrainVar(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}

IPdb [9]: grads
Out  [9]: {'Linear27.weight': Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, 'Alpha55.weight': Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>}

The problem seems to be that the leaves of last_grad are TrainVars while those of grads are plain arrays. Is there a way to solve this?

chaoming0625 commented 5 months ago

One way to fix it is new_grad = jax.tree_map(bm.add, last_grad, grads, is_leaf=bm.is_bp_array). Does that fix it?
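
A minimal sketch of why is_leaf matters here, using a single hypothetical weight 'w': TrainVar is itself a registered pytree node, so without is_leaf, jax.tree_map tries to match its internal structure against the plain gradient array and raises the "Custom node type mismatch" error above; treating BrainPy arrays as leaves makes both trees line up:

import jax
import brainpy.math as bm

last_grad = {'w': bm.TrainVar(bm.zeros(3))}  # dict of TrainVar, as produced by tree_map over train_vars()
grads = {'w': bm.ones(3).value}              # dict of plain jax arrays, as returned by bm.grad

# jax.tree_map(bm.add, last_grad, grads)     # fails: pytree structure mismatch
new_grad = jax.tree_map(bm.add, last_grad, grads, is_leaf=bm.is_bp_array)  # works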

CloudyDory commented 5 months ago

It fixes the error in this line, but introduces a new error at the line grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch)):

TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
  * the input carry component carry[1][<flat index 0>] is a <class 'brainpy._src.math.object_transform.variables.TrainVar'> but the corresponding component of the carry output is a <class 'brainpy._src.math.ndarray.Array'>, so their Python types differ.
  * the input carry component carry[1][<flat index 1>] is a <class 'brainpy._src.math.object_transform.variables.TrainVar'> but the corresponding component of the carry output is a <class 'brainpy._src.math.ndarray.Array'>, so their Python types differ.
Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.

Printing the contents of last_grad and new_grad inside grad_fun() gives:

IPdb [2]: last_grad
Out  [2]: {'Linear30.weight': TrainVar(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha61.weight': TrainVar(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}

IPdb [3]: new_grad
Out  [3]: {'Linear30.weight': Array(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha61.weight': Array(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}

As you can see, one is a TrainVar but the other is not, so we still have this error.

chaoming0625 commented 5 months ago

OK, the error is caused by inconsistent types. I will fix it later. For now, we can use:

new_grad = jax.tree_map(lambda x, y: bm.asarray(bm.add(x, y)), last_grad, grads)  # accumulate gradients

This may fix the error.

CloudyDory commented 5 months ago
new_grad = jax.tree_map(lambda x, y: bm.asarray(bm.add(x, y)), last_grad, grads)  # accumulate gradients

Thanks for the suggestion, but unfortunately this does not fix the error. new_grad is still an Array, not a TrainVar. By the way, we also need to add is_leaf=bm.is_bp_array in this line.

chaoming0625 commented 5 months ago

OK, we should change it to

new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array)  # accumulate gradients

or, alternatively, change the gradient initialization line in train() to

grads = {k: bm.zeros(v.shape) for k, v in model.train_vars().unique().items()}
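
Putting the suggestions together, a sketch of the accumulation step that keeps the scan carry type consistent (TrainVar in, TrainVar out), assuming the carry is initialized with jax.tree_map(bm.zeros_like, model.train_vars().unique()) as in the earlier snippet:

def grad_fun(last_grad, input_target):
    x_single, y_single = input_target  # [height, width], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(),
                     has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])
    # Wrap the accumulated values back into TrainVar so the carry keeps the same pytree type.
    new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)),
                            last_grad, grads, is_leaf=bm.is_bp_array)
    return new_grad, (loss, acc)
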
CloudyDory commented 5 months ago

Thanks for the suggestions; the first one works but the second one doesn't. However, with the first suggestion we have one more problem: the output dimensions of bm.scan seem to differ with and without the bm.jit decoration.

The training function now looks like this:

def train(xs, ys):
    # xs: [batch, height, width]
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, loss_acc = bm.scan(grad_fun, grads, (xs, ys))
    opt.update(grads)
    return loss_acc

If we don't jit this function, the outputs of bm.scan(grad_fun, grads, (xs, ys)) are as expected, and training can proceed without error.

However, if we jit this function, both outputs become arrays with a leading dimension of 64, where 64 is the batch size, and the parameter update step opt.update(grads) throws an error:

MathError: The length of "grads" must be equal to "self.vars_to_train", while we got 64 != 5!

This looks weird to me, as I never expected jit to change the dimensions of the outputs.

chaoming0625 commented 5 months ago

By batch_size here, do you mean N_mini_batch?

CloudyDory commented 5 months ago

By batch_size here, do you mean N_mini_batch?

Yes!

CloudyDory commented 5 months ago

By batch_size here, do you mean N_mini_batch?

I have updated the code. xs should have dimension [batch, height, width].

chaoming0625 commented 5 months ago

Maybe I understand why the error occurs.


def train(xs, ys):
    # xs: [batch, height, width]
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, loss_acc = bm.scan(grad_fun, grads, (xs, ys))
    opt.update(grads)
    return loss_acc

In your code, grads is an array with size 64, rather than the dict or ArrayCollector that is expected. So the Adam optimizer evaluates the length of grads as 64, rather than as the number of gradients of the corresponding 5 weights.

I think you should check whether grads is being wrongly assigned.

chaoming0625 commented 5 months ago

OK, I think this is the bug of bm.scan. Wait a minute.

chaoming0625 commented 5 months ago

#604 will fix the issue!

CloudyDory commented 5 months ago

#604 will fix the issue!

Thank you very much! I am not an expert in JAX, but what is the cause of this issue?

chaoming0625 commented 5 months ago

brainpy.math.scan uses two passes to compile the model. The first pass is an evaluation pass, which finds all the Variables used in the code. The second pass is the compilation pass, which compiles the model for the accelerator device.

The error occurred during the first pass, which did not return the correct outputs new_grad, (loss, acc) but instead returned only (loss, acc).

CloudyDory commented 5 months ago

OK, thanks!

chaoming0625 commented 5 months ago

Please reinstall the latest BrainPy codebase after #604 is merged. The update will fix the error, and the whole program will be compiled correctly. We are sorry for the inconvenience.

CloudyDory commented 5 months ago

No problem. Thank you very much!