Thanks for the report. The rotation mode may have been fixed some time ago, but I will check whether the gradients are correct.
This is also what I hope to know. How can I check the gradients in BrainPy?
I wrote a simple script to check whether the gradients are the same. The answer is yes.
import functools

import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm

# Gradient-equivalence check: build the same network with the 'rotation' and
# 'concat' delay update methods and compare the resulting gradients.
snn_latency = 20
dt = 1.0

bm.clear_buffer_memory()
bm.set(float_=bm.float32, mode=bm.training_mode)
bm.set_platform('cpu')
bm.set_dt(dt)


# %% Network definition
class Network(bp.DynSysGroup):
    def __init__(self, method):
        super().__init__()
        self.neu = bp.dyn.Lif(size=2, V_rest=0.0, V_reset=0.0, V_th=1.0, spk_fun=bm.surrogate.Arctan())
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(self.neu.spike, delay_len=self.delay_len, update_method=method)
        self.weight = bm.TrainVar(bm.random.randn(2, 2))
        self.bias = bm.TrainVar(bm.random.randn(2))

    def reset_state(self, *args):
        self.neu.reset_state(self.neu.mode)
        self.spike_buffer.reset(self.neu.spike, delay_len=self.delay_len)

    def update(self, data):
        spike = self.neu(self.weight @ data + self.bias)           # [batch, 2]
        self.spike_buffer.update(spike)
        spike_delay = self.spike_buffer.retrieve(self.delay_len)   # [batch, 2]
        return spike_delay


def train1(method='rotation'):
    # %% Create network and optimizer
    model = Network(method)
    optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

    # %% Training functions
    def loss_fun(x_single, y_single):
        '''
        Inputs:
            x_single: [feature]
            y_single: [1]
        '''
        indices = np.arange(snn_latency)  # sequence length
        model.reset_state()
        spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
        firerate = bm.sum(spike, axis=0) + 1.0e-6      # [batch=1, 2]
        predict = bm.log(firerate / bm.sum(firerate))  # log-probabilities, [batch=1, n_class]
        loss = bp.losses.nll_loss(-predict, y_single)  # scalar
        acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
        return loss, acc

    def grad_fun(last_grad, x_y_single):
        '''
        Inputs:
            last_grad: PyTree of gradients of each trainable parameter.
            x_y_single: tuple of ([feature], scalar), a single training sample.
        '''
        x_single, y_single = x_y_single  # [feature], scalar
        grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
        grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
        new_grad = jax.tree_map(bm.add, last_grad, grads)    # accumulate gradients
        return new_grad, (loss, acc)

    @bm.jit
    def train(x_batch, y_batch):
        '''
        Inputs:
            x_batch: [batch, feature]
            y_batch: [batch]
        '''
        train_vars = model.train_vars().unique()
        # Gradient accumulation over the batch
        grads = jax.tree_map(bm.zeros_like, train_vars)
        grads, (losses, accs) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
        optimizer.update(grads)
        return grads

    return train


# %% Fake data: two Gaussian clusters
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)                 # [batch]

bm.random.seed(0)
bm.clear_name_cache()
f1 = train1('rotation')

bm.random.seed(0)
bm.clear_name_cache()
f2 = train1('concat')

for e in range(10):
    # with jax.disable_jit():
    grad1 = f1(train_data, train_label)
    grad2 = f2(train_data, train_label)
    print(jax.tree_map(bm.allclose, grad1, grad2))
Hi, what I actually hope to know is where the gradients are stored in BrainPy. For example, in PyTorch the trainable parameters have a grad field that stores the gradient values. Is there a similar field in BrainPy variables?
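For reference, this is the PyTorch behaviour I mean (a toy example, unrelated to the model above):

# Toy PyTorch example of the ".grad" field mentioned above.
import torch

w = torch.randn(2, 2, requires_grad=True)
x = torch.randn(2)
loss = (w @ x).sum()
loss.backward()
print(w.grad)  # the gradient is stored on the parameter itself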
The gradients are not stored in a fixed place; they are only returned after the function is computed. In the following example, the gradients are stored in grads:
# "grad_vars" specify the target to compute gradients
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
# "grads" return as the function output
grads, loss, acc = grad_f(x_single, y_single[None])
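So there is no equivalent of PyTorch's .grad field; if you want to keep the gradients around, hold on to the returned PyTree yourself. A minimal sketch, assuming the model, optimizer, and grad_f defined in the script above, and assuming the returned grads mirrors the structure of model.train_vars().unique() (a dict keyed by variable name):

# Minimal sketch (not from the thread): inspect and apply the returned gradients,
# assuming "model", "optimizer", "grad_f", "x_single", "y_single" are defined as above.
grads, loss, acc = grad_f(x_single, y_single[None])
for name, g in grads.items():   # grads is assumed to mirror model.train_vars().unique()
    print(name, g.shape)        # per-variable gradient, looked up by name
optimizer.update(grads)         # apply the gradients, analogous to optimizer.step() in PyTorch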
Thank you very much for the information!
In #626 it was mentioned that the rotation method of delay variables does not implement autograd functionality. However, I have tested this during training and found that the parameters can be trained normally. Is there a misunderstanding in that issue?
Outputs: