brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0
514 stars 93 forks source link

Questions about backpropagation through delay variables #633

Closed CloudyDory closed 7 months ago

CloudyDory commented 7 months ago

In #626 we mentioned that the rotation method in delay variables does not implement autograd functionality. However, I tested this in training and found that the parameters are trained normally. Did I misunderstand the issue?

import functools
import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

# Number of simulation steps per sample (the sequence length used by loss_fun).
snn_latency = 20
# Time step handed to bm.set_dt below.
dt = 1.0

# Global BrainPy configuration: float32 floats, CPU platform, and time step dt.
# NOTE(review): clear_buffer_memory presumably releases buffers left over from
# an earlier run in the same session -- confirm against the BrainPy docs.
bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')
bm.set_dt(dt)

#%% Network definition
class Network(bp.DynSysGroup):
    '''
    Two LIF neurons whose spikes are read back through a delay buffer.

    The input is passed through a trainable affine map (``weight``, ``bias``)
    into the neurons. ``update`` pushes the new spikes into a rotation-mode
    ``LengthDelay`` and returns what ``retrieve(delay_len)`` yields.
    '''

    def __init__(self):
        super().__init__()
        self.neu = bp.dyn.Lif(
            size=2,
            V_rest=0.0,
            V_reset=0.0,
            V_th=1.0,
            spk_fun=bm.surrogate.Arctan(),
        )
        self.delay_len = 2
        self.spike_buffer = bm.LengthDelay(
            self.neu.spike,
            delay_len=self.delay_len,
            update_method='rotation',
        )
        # Trainable parameters; weight is drawn before bias.
        self.weight = bm.TrainVar(bm.random.randn(2, 2))
        self.bias = bm.TrainVar(bm.random.randn(2))

    def reset_state(self, *args):
        '''Reset the neuron group, then re-initialise the delay buffer from its spike variable.'''
        neuron = self.neu
        neuron.reset_state(neuron.mode)
        self.spike_buffer.reset(neuron.spike, delay_len=self.delay_len)

    def update(self, data):
        '''Advance the network one step on ``data`` and return the delayed spikes.'''
        current = self.weight @ data + self.bias
        spike_now = self.neu(current)  # [batch, 2]
        self.spike_buffer.update(spike_now)
        return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]

#%% Create network and fake data
print('Creating network... ')
# Construct the model and its Adam optimizer inside BrainPy's training environment.
with bm.training_environment():
    model = Network()
    optimizer = bp.optim.Adam(
        lr=1.0,
        train_vars=model.train_vars().unique(),
    )

print('Creating data... ')
# Two clusters of 100 standard-normal samples each, shifted to (-1, -1) and (1, 1).
train_data = np.vstack([
    np.random.randn(100, 2) + np.array([[-1, -1]]),
    np.random.randn(100, 2) + np.array([[1, 1]]),
])  # [batch, 2]
# Labels follow the stacking order: 0 for the first cluster, 1 for the second.
train_label = bm.concatenate(
    [bm.zeros(100, dtype=bm.int32), bm.ones(100, dtype=bm.int32)],
    axis=0,
)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Run the network on one sample and score its spike-count prediction.

    Inputs:
        x_single: [feature]
        y_single: [1]
    Returns:
        (loss, acc): the loss value and the accuracy for this sample.
    '''
    model.reset_state()

    # The same input is presented at every one of the snn_latency steps.
    run_step = functools.partial(model.step_run, data=x_single)
    steps = np.arange(snn_latency)  # sequence length
    spike = bm.for_loop(run_step, steps)  # [length, batch=1, 2], float32

    # Spike counts summed over time; the small offset keeps the log finite.
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]
    predict = bm.log(firerate / bm.sum(firerate))  # log-probabilities, [batch=1, n_class]

    # nll_loss receives the negated log-probabilities.
    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

# Gradient of loss_fun with respect to the model's trainable variables.
# It is called below as: grads, loss, acc = grad_f(x, y).
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
    '''
    Scan body: add the gradients of one training sample to the running total.

    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    Returns:
        (new_grad, (loss, acc)): the updated gradient PyTree as the carry, and
        this sample's loss and accuracy as the per-step output.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    # y_single[None] adds a leading axis so the label has shape [1], as loss_fun expects.
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_map is deprecated and has been removed from recent JAX releases;
    # jax.tree_util.tree_map is the same function and exists in every version.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Accumulate gradients over the whole batch and apply one optimizer update.

    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    Returns:
        (loss, acc): mean loss and mean accuracy over the batch.
    '''
    train_vars = model.train_vars().unique()

    # Gradient accumulation: start from zeros and scan grad_fun over the samples.
    # jax.tree_util.tree_map replaces jax.tree_map, which is deprecated and has
    # been removed from recent JAX releases.
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    optimizer.update(grads)

    loss = losses.mean()  # scalar
    acc = acces.mean()    # scalar
    return loss, acc

#%% Start training
print('Start training...')
train_epochs = 10
# Per-epoch metrics, filled in by the loop below.
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)

for e in range(train_epochs):
    # Debugging aid: wrap the train call in the context manager below to run without JIT.
    # with jax.disable_jit():
    # Each call passes the full dataset to train, which applies one optimizer update.
    train_loss[e], train_acc[e] = train(train_data, train_label)
    print("Epoch {}, train_loss={:.3f}, train_acc={:.2f}%".format(e, train_loss[e], train_acc[e]*100.0))

print('Done!')

Outputs:

Creating network... 
Creating data... 
Start training...
Epoch 0, train_loss=4.572, train_acc=50.00%
Epoch 1, train_loss=1.373, train_acc=77.00%
Epoch 2, train_loss=0.757, train_acc=87.00%
Epoch 3, train_loss=0.767, train_acc=89.50%
Epoch 4, train_loss=0.636, train_acc=91.00%
Epoch 5, train_loss=0.642, train_acc=91.00%
Epoch 6, train_loss=0.568, train_acc=89.50%
Epoch 7, train_loss=0.570, train_acc=89.50%
Epoch 8, train_loss=0.576, train_acc=90.00%
Epoch 9, train_loss=0.582, train_acc=91.00%
Done!
chaoming0625 commented 7 months ago

Thanks for the report. The rotation mode may have been fixed some time ago, but I will check whether the gradients are correct.

CloudyDory commented 7 months ago

Thanks for the report. The rotation mode may have been fixed some time ago, but I will check whether the gradients are correct.

This is also what I would like to know. How can I check gradients in BrainPy?

chaoming0625 commented 7 months ago

I wrote a simple script to check whether the gradients are the same for the 'rotation' and 'concat' update methods. The answer is yes.

import functools

import jax
import numpy as np

import brainpy as bp
import brainpy.math as bm

# Number of simulation steps per sample (the sequence length used by loss_fun).
snn_latency = 20
# Time step handed to bm.set_dt below.
dt = 1.0

# Global BrainPy configuration: float32 floats, training mode, CPU platform, and time step dt.
# NOTE(review): clear_buffer_memory presumably releases buffers left over from
# an earlier run in the same session -- confirm against the BrainPy docs.
bm.clear_buffer_memory()
bm.set(float_=bm.float32, mode=bm.training_mode)
bm.set_platform('cpu')
bm.set_dt(dt)

# %% Network definition
class Network(bp.DynSysGroup):
  '''
  Two LIF neurons whose spikes are read back through a delay buffer.

  ``method`` is forwarded as ``update_method`` to the ``LengthDelay`` buffer,
  so the same network can be built with different delay implementations.
  '''

  def __init__(self, method):
    super().__init__()
    self.neu = bp.dyn.Lif(
      size=2,
      V_rest=0.0,
      V_reset=0.0,
      V_th=1.0,
      spk_fun=bm.surrogate.Arctan(),
    )
    self.delay_len = 2
    self.spike_buffer = bm.LengthDelay(
      self.neu.spike,
      delay_len=self.delay_len,
      update_method=method,
    )
    # Trainable parameters; weight is drawn before bias.
    self.weight = bm.TrainVar(bm.random.randn(2, 2))
    self.bias = bm.TrainVar(bm.random.randn(2))

  def reset_state(self, *args):
    '''Reset the neuron group, then re-initialise the delay buffer from its spike variable.'''
    neuron = self.neu
    neuron.reset_state(neuron.mode)
    self.spike_buffer.reset(neuron.spike, delay_len=self.delay_len)

  def update(self, data):
    '''Advance the network one step on ``data`` and return the delayed spikes.'''
    current = self.weight @ data + self.bias
    spike_now = self.neu(current)  # [batch, 2]
    self.spike_buffer.update(spike_now)
    return self.spike_buffer.retrieve(self.delay_len)  # [batch, 2]

def train1(method='rotation'):
  '''
  Build a network with the given delay update method and return its training step.

  Inputs:
      method: forwarded to Network as the delay buffer's update method
          ('rotation' and 'concat' are the values used in this script).
  Returns:
      train: jitted function (x_batch, y_batch) -> PyTree of accumulated gradients.
  '''
  # %% Create network and optimizer
  model = Network(method)
  optimizer = bp.optim.Adam(lr=1.0, train_vars=model.train_vars().unique())

  # %% Training functions
  def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [feature]
        y_single: [1]
    Returns:
        (loss, acc) for this single sample.
    '''
    indices = np.arange(snn_latency)  # sequence length

    model.reset_state()
    spike = bm.for_loop(functools.partial(model.step_run, data=x_single), indices)  # [length, batch=1, 2], float32
    firerate = bm.sum(spike, axis=0) + 1.0e-6  # [batch=1, 2]
    predict = bm.log(firerate / bm.sum(firerate))  # log-probabilities, [batch=1, n_class]

    loss = bp.losses.nll_loss(-predict, y_single)  # scalar
    acc = bm.mean(predict.argmax(-1) == y_single)  # scalar
    return loss, acc

  def grad_fun(last_grad, x_y_single):
    '''
    Scan body: add the gradients of one training sample to the running total.

    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
    grads, loss, acc = grad_f(x_single, y_single[None])  # PyTree of gradients, scalar, scalar
    # jax.tree_map is deprecated and has been removed from recent JAX releases;
    # jax.tree_util.tree_map is the same function and exists in every version.
    new_grad = jax.tree_util.tree_map(bm.add, last_grad, grads)  # accumulate gradients
    return new_grad, (loss, acc)

  @bm.jit
  def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    Returns:
        PyTree of gradients accumulated over the batch (also applied by the optimizer).
    '''
    train_vars = model.train_vars().unique()
    # Gradient accumulation
    grads = jax.tree_util.tree_map(bm.zeros_like, train_vars)
    # Only the accumulated gradients are needed here; per-sample loss/acc are discarded.
    grads, _ = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients
    optimizer.update(grads)
    return grads

  return train

# Shared dataset for both networks: two clusters of 100 samples, shifted to (-1, -1) and (1, 1).
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1, -1]]),
                             np.random.randn(100, 2) + np.array([[1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32),
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

# Reset the BrainPy seed and name cache before each build so both networks
# are constructed from the same random state.
bm.random.seed(0)
bm.clear_name_cache()
f1 = train1('rotation')

bm.random.seed(0)
bm.clear_name_cache()
f2 = train1('concat')

for e in range(10):
  # with jax.disable_jit():
  grad1 = f1(train_data, train_label)
  grad2 = f2(train_data, train_label)
  # Compare the two gradient PyTrees leaf by leaf. jax.tree_util.tree_map
  # replaces jax.tree_map, which is deprecated and removed in recent JAX releases.
  print(jax.tree_util.tree_map(bm.allclose, grad1, grad2))
CloudyDory commented 7 months ago

Hi, what I actually want to know is where the gradients are stored in BrainPy. For example, in PyTorch each trained parameter has a grad field that stores its gradient values. Is there a similar field in BrainPy variables?

chaoming0625 commented 7 months ago

The gradients are not stored in any fixed place. They are only returned as the output of the gradient function after it has been computed. In the following example, the gradients are returned as grads:

# "grad_vars" specifies the variables to compute gradients for
grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)
# "grads" is returned as the first output of the function call
grads, loss, acc = grad_f(x_single, y_single[None])
CloudyDory commented 7 months ago

Thank you very much for the information!