Closed CloudyDory closed 5 months ago
Actually, they are all using brainpy.math.for_loop. Therefore, DSTrainer can also support the generation of data at runtime. But I recommend using brainpy.math.for_loop or brainpy.math.jit instead.
For example, if your data is organized as:
def your_data_at_each_time_step():
    return  # something

def step_run(i):
    inp = your_data_at_each_time_step()
    return model.step_run(i, inp)

outs = bm.for_loop(step_run, np.arange(1000))
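The same idea can be sketched in plain JAX rather than BrainPy: the input for each step is computed inside the scanned function from the step index, so no full input array is ever stored. The toy leaky-integrator update and the random input are assumptions made for illustration; they stand in for model.step_run and the real data generator.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

def step(state, i):
    # Generate this step's input from the index; nothing is pre-stored.
    inp = jax.random.normal(jax.random.fold_in(key, i), (4,))
    # Toy update standing in for model.step_run (an assumption).
    new_state = 0.9 * state + 0.1 * inp
    return new_state, new_state.mean()

# Scan over the step indices only; outs collects one value per step.
final_state, outs = jax.lax.scan(step, jnp.zeros((4,)), jnp.arange(1000))
```

Unlike a BrainPy model, which keeps its state in Variables, this sketch passes the state through the scan carry explicitly.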
However, if your data is generated or loaded from an external source, you can use jit:
@bm.jit
def step_run(i, inp):
    return model.step_run(i, inp)

for i in np.arange(10000):
    # load data
    inp = ....
    out = step_run(i, inp)
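A minimal sketch of this second pattern in plain JAX, not BrainPy: only the per-step function is compiled, and an ordinary Python loop fetches one input at a time. The load_frame helper and the toy step function are hypothetical stand-ins for the real data source and model.step_run.

```python
import jax
import jax.numpy as jnp
import numpy as np

def load_frame(i, shape=(4, 4)):
    """Hypothetical loader: returns one input frame for time step i."""
    rng = np.random.default_rng(i)
    return rng.random(shape).astype(np.float32)

@jax.jit
def step(state, inp):
    """Toy leaky integrator standing in for model.step_run."""
    new_state = 0.9 * state + 0.1 * inp
    return new_state, new_state.sum()

state = jnp.zeros((4, 4), dtype=jnp.float32)
outs = []
for i in range(100):
    inp = load_frame(i)            # only one frame is held in memory at a time
    state, out = step(state, inp)  # compiled once, reused on every iteration
    outs.append(out)
outs = jnp.stack(outs)             # shape: (100,)
```

The trade-off is that the loop itself runs in Python, so each step pays a dispatch cost that a fully compiled loop avoids.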
Probably I need to provide a more complete picture of the problem I have here.
The training process is:
1. Load a batch of training images with dimension [batch, height, width] and the corresponding class labels with dimension [batch].
2. Extend each image into a sequence with dimension [length, height, width], where the sequence is the training image for idx1 <= i < idx2, and zeros for 0 <= i < idx1 or idx2 <= i < length.
3. Run the network on the sequence for every time step 0 <= i < length.
The problem is in step 2: length can be quite large, and the extended image sequence can occupy too much memory. Therefore, I have created a custom indexable class to store the sequence in a compressed form. It stores only the two unique images in the sequence and selects one as output given the query index.
from collections.abc import Sequence

# Iterator object to efficiently store the input image
class CompressedImage(Sequence):
    def __init__(self, img1, img2, img2_idx_range, length):
        self.i = 0
        self.img1 = img1
        self.img2 = img2
        self.img2_idx_range = img2_idx_range
        self.length = length  # in number of time points
        self.shape = (length,) + img1.shape

    def __iter__(self):
        return self

    def __next__(self):
        if self.i < self.length:
            if self.i < self.img2_idx_range[0] or self.i >= self.img2_idx_range[1]:
                img = self.img1
            else:
                img = self.img2
        else:
            raise StopIteration
        self.i += 1
        return img

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Support negative indexing
        if i < 0:
            i = len(self) + i
        if 0 <= i and i < self.length:
            if i < self.img2_idx_range[0] or i >= self.img2_idx_range[1]:
                img = self.img1
            else:
                img = self.img2
        else:
            raise IndexError("array index ({}) out of range [0, {})".format(i, len(self)))
        return img
And my training section looks like this:
import jax
import numpy as np
import brainpy as bp
import brainpy.math as bm

blank_img = bm.zeros([cfg['height'], cfg['width']], dtype=bm.float_)  # [height, width]

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, x_single):
    bp.share.save(i=i, t=i * bm.get_dt())
    return model.step_run(i, x_single)

# define the loss function
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    inputs_compressed = CompressedImage(x_single, blank_img, [cfg['stim_start_timepoint'], cfg['stim_end_timepoint']], cfg['total_timepoint'])
    model.reset_state()
    spike_out = bm.for_loop(step_run, (indices, inputs_compressed))  # [n_out, length]
    frate_out = bm.sum(spike_out, axis=1) + 1.0e-6  # [n_out]
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_out]
    loss = bp.losses.nll_loss(-predicts, y_single)  # Need to manually add a negative sign because BrainPy does not do so. scalar
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

def grad_fun(last_grad, input_target):
    x_single, y_single = input_target  # [height, width], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch]
    '''
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))
    loss = losses.mean()
    acc = acces.mean()
    opt.update(grads)
    return loss, acc
However, the code generates an error in the line spike_out = bm.for_loop(step_run, (indices, inputs_compressed)):
*** TypeError: Value <__main__.CompressedImage object at 0x7f10c82b5f10> with type <class '__main__.CompressedImage'> is not a valid JAX type
The problem can potentially be solved by jitting only the step_run() function, as in your second recommendation. However, it may be slower than using bm.for_loop, since it uses an explicit Python for-loop (see updates in issue #552). So I am still not sure how to pass a customized input generator into bm.for_loop().
This problem can be solved easily. The trick is to use where.
import functools  # needed for functools.partial below

blank_img = bm.zeros([cfg['height'], cfg['width']], dtype=bm.float_)  # [height, width]

# define the optimizer we need
opt = bp.optim.Adam(lr=1e-3, train_vars=model.train_vars().unique())

def step_run(i, x_single):
    x = bm.where(i < cfg['stim_start_timepoint'],
                 x_single,
                 blank_img)
    return model.step_run(i, x)

# define the loss function
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [height, width]
        y_single: [1]
    '''
    indices = np.arange(cfg['total_timepoint'])  # sequence length
    model.reset_state()
    spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), (indices,))  # [n_out, length]
    frate_out = bm.sum(spike_out, axis=1) + 1.0e-6  # [n_out]
    predicts = bm.log(frate_out / bm.sum(frate_out)).unsqueeze(0)  # log-probabilities, [batch=1, n_out]
    loss = bp.losses.nll_loss(-predicts, y_single)  # Need to manually add a negative sign because BrainPy does not do so. scalar
    acc = bm.mean(predicts.argmax(-1) == y_single)  # scalar
    return loss, acc

def grad_fun(last_grad, input_target):
    x_single, y_single = input_target  # [height, width], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])
    new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, height, width]
        y_batch: [batch]
    '''
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))
    loss = losses.mean()
    acc = acces.mean()
    opt.update(grads)
    return loss, acc
The modifications lie in the step_run() function and the line:
spike_out = bm.for_loop(functools.partial(step_run, x_single=x_single), (indices,))
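The where-plus-partial trick can be sketched in plain JAX as well. This version is an illustration under stated assumptions, not the BrainPy code above: it follows the window described earlier in the thread (image for idx1 <= i < idx2, zeros elsewhere), so it checks both bounds, and the accumulating "model" is a toy stand-in for model.step_run. Swap the two branches of jnp.where if the opposite convention is intended.

```python
import functools
import jax
import jax.numpy as jnp

def step(carry, i, *, image, start, stop):
    """Select the frame for step i without materializing the sequence."""
    in_window = (i >= start) & (i < stop)
    frame = jnp.where(in_window, image, jnp.zeros_like(image))
    # Toy "model": accumulate the input (stands in for model.step_run).
    carry = carry + frame
    return carry, frame.sum()

image = jnp.ones((2, 3), dtype=jnp.float32)
length, start, stop = 10, 3, 7

# functools.partial binds the fixed image and window bounds, so only the
# step index is scanned over.
body = functools.partial(step, image=image, start=start, stop=stop)
final, per_step = jax.lax.scan(body, jnp.zeros_like(image), jnp.arange(length))
```

Only the single image and the index array are passed into the loop, so memory use does not grow with length.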
Thank you very much! I knew about bm.where, but did not know about functools.partial before.
I now have a new error in the gradient accumulation line new_grad = jax.tree_map(bm.add, last_grad, grads):
File ~/project/snn_model/model_train.py:651 in train
grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:997 in scan
rets = transform(init, operands)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:923 in call
return jax.lax.scan(f=fun2scan,
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/traceback_util.py:179 in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:258 in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/loops.py:244 in _create_jaxpr
jaxpr, consts, out_tree = _initial_style_jaxpr(
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py:67 in _initial_style_jaxpr
jaxpr, consts, out_tree = _initial_style_open_jaxpr(
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/lax/control_flow/common.py:61 in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/profiler.py:336 in wrapper
return func(*args, **kwargs)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2288 in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/interpreters/partial_eval.py:2310 in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/linear_util.py:191 in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:914 in fun2scan
carry, results = body_fun(carry, x)
File ~/project/snn_model/model_train.py:640 in grad_fun
new_grad = jax.tree_map(bm.add, last_grad, grads) # accumulate gradients
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/tree_util.py:243 in tree_map
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
File ~/miniconda3/envs/brainpy/lib/python3.11/site-packages/jax/_src/tree_util.py:243 in <listcomp>
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Custom node type mismatch: expected type: <class 'brainpy._src.math.object_transform.variables.TrainVar'>, value: Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>.
The contents of last_grad and grads are:
IPdb [8]: last_grad
Out [8]: {'Linear27.weight': TrainVar(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha55.weight': TrainVar(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}
IPdb [9]: grads
Out [9]: {'Linear27.weight': Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, 'Alpha55.weight': Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>}
The problem seems to be that last_grad is a TrainVar but grads is not. Is there a way to solve the problem?
One way to fix it is new_grad = jax.tree_map(bm.add, last_grad, grads, is_leaf=bm.is_bp_array). Does that fix it?
It fixes the error in this line, but introduces a new error in the line grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch)):
TypeError: Scanned function carry input and carry output must have the same pytree structure, but they differ:
* the input carry component carry[1][<flat index 0>] is a <class 'brainpy._src.math.object_transform.variables.TrainVar'> but the corresponding component of the carry output is a <class 'brainpy._src.math.ndarray.Array'>, so their Python types differ.
* the input carry component carry[1][<flat index 1>] is a <class 'brainpy._src.math.object_transform.variables.TrainVar'> but the corresponding component of the carry output is a <class 'brainpy._src.math.ndarray.Array'>, so their Python types differ.
Revise the scanned function so that its output is a pair where the first element has the same pytree structure as the first argument.
I can print the contents of last_grad and new_grad in grad_fun():
IPdb [2]: last_grad
Out [2]: {'Linear30.weight': TrainVar(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha61.weight': TrainVar(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}
IPdb [3]: new_grad
Out [3]: {'Linear30.weight': Array(value=Traced<ShapedArray(float32[3506881])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32), 'Alpha61.weight': Array(value=Traced<ShapedArray(float32[70139111])>with<DynamicJaxprTrace(level=2/0)>, dtype=float32)}
As you can see, one is a TrainVar but the other is not, so we still have this error.
OK, the error is caused by the inconsistent types. I will fix it later. But currently, we can use:
new_grad = jax.tree_map(lambda x, y: bm.asarray(bm.add(x, y)), last_grad, grads) # accumulate gradients
This may fix the error.
Thanks for the suggestion, but unfortunately this does not fix the error. new_grad is still an Array, not a TrainVar.
By the way, we also need to add is_leaf=bm.is_bp_array in this line.
OK, we should change it to
new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array) # accumulate gradients
Or we should change the line that initializes grads to
grads = {k: bm.zeros(v.shape) for k, v in model.train_vars().unique().items()}
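The underlying constraint is that a scan carry must keep the same pytree structure and leaf types on every iteration. A minimal plain-JAX sketch of gradient accumulation under that constraint is shown below; the two-parameter linear model and the squared-error loss are assumptions made for illustration, and BrainPy's TrainVar wrapper is deliberately left out, so this shows the principle rather than a drop-in fix.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameters: a dict of plain arrays.
params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

def loss_fn(p, x, y):
    pred = jnp.dot(p["w"], x) + p["b"]
    return (pred - y) ** 2

def accumulate(acc, sample):
    x, y = sample
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # acc and grads are both dicts of plain arrays, so the carry keeps
    # the same pytree structure from one iteration to the next.
    new_acc = jax.tree_util.tree_map(jnp.add, acc, grads)
    return new_acc, loss

xs = jnp.arange(12.0).reshape(4, 3)   # [batch, features]
ys = jnp.ones((4,))                   # [batch]
init = jax.tree_util.tree_map(jnp.zeros_like, params)
total_grads, losses = jax.lax.scan(accumulate, init, (xs, ys))
```

The sketch uses jax.tree_util.tree_map, which is the long-standing location of this function; the top-level jax.tree_map alias used in the thread is deprecated in newer JAX releases.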
Thanks for the suggestions. The first one works but the second one doesn't. However, with the first suggestion, we have one more problem: the output dimension of bm.scan seems to differ with and without the bm.jit decoration.
The training function now looks like this:
def train(xs, ys):
    # xs: [batch, height, width]
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, loss_acc = bm.scan(grad_fun, grads, (xs, ys))
    opt.update(grads)
    return loss_acc
If we don't jit this function, the outputs of bm.scan(grad_fun, grads, (xs, ys)) are:
- grads: a dict-like object of <class 'brainpy._src.math.object_transform.collectors.ArrayCollector'>. Each key contains an array with the same dimension as the corresponding network parameter.
- loss_acc: a tuple containing two 1-dimensional arrays of shape (batch_size,).
And training can proceed without error.
However, if we jit this function, the outputs become:
- grads: Traced<ShapedArray(float32[64])>with<DynamicJaxprTrace(level=1/0)>
- loss_acc: Traced<ShapedArray(float32[64])>with<DynamicJaxprTrace(level=1/0)>
So now both outputs become a 64-dimensional array, with 64 being the batch size, and the parameter update step opt.update(grads) throws an error:
MathError: The length of "grads" must be equal to "self.vars_to_train", while we got 64 != 5!
This looks weird to me, as I never expected jit to change the dimension of the outputs.
By batch_size here, do you mean N_mini_batch?
Yes!
I have updated the code. xs should have dimension [batch, height, width].
Maybe I understand why the error occurs.
def train(xs, ys):
    # xs: [batch, height, width]
    grads = jax.tree_map(bm.zeros_like, model.train_vars().unique())
    grads, loss_acc = bm.scan(grad_fun, grads, (xs, ys))
    opt.update(grads)
    return loss_acc
In your code, grads is an array of size 64, rather than the dict or ArrayCollector that is expected. So the Adam optimizer evaluates the length of grads as 64, rather than 5, the number of weights that have gradients.
I think you should check whether grads is wrongly assigned.
OK, I think this is a bug in bm.scan. Wait a minute.
#604 will fix the issue!
Thank you very much! I am not an expert in JAX, but what is the cause of this issue?
brainpy.math.scan uses two passes to compile the model. The first pass is the evaluation, which finds all Variables used in the code. The second pass is the compilation, which compiles the model for the acceleration device.
The error occurred during the first pass: it does not return the correct outputs new_grad, (loss, acc); instead it returns only (loss, acc).
OK, thanks!
Please reinstall the latest BrainPy codebase after #604 is merged. The update will fix the error, and the whole program will be compiled correctly. We are sorry for the inconvenience.
No problem. Thank you very much!
When simulating a network, sometimes it is not convenient to pre-generate the entire input data due to memory constraints. If we use DSRunner, we can supply an iterable or functional input when initializing the runner, so the input data can be stored in a compressed form or generated at runtime. However, it seems that this is not possible when using DSTrainer. The DSTrainer object requires supplying the entire input data when calling the predict() function. Is it possible to overcome this issue?