brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

Gradient accumulation generates JAX leaking error in BrainPy 2.5.0 #627

Closed: CloudyDory closed this issue 5 months ago.

CloudyDory commented 5 months ago

After upgrading to BrainPy 2.5.0, I found that training with gradient accumulation no longer works.

We can use logistic regression as an example:

import numpy as np
import jax
import brainpy as bp
import brainpy.math as bm

bm.clear_buffer_memory()
bm.set(float_=bm.float32)
bm.set_platform('cpu')

#%% Network definition
class Network(bp.DynSysGroup):
    def __init__(self):
        super().__init__()
        self.weight = bm.TrainVar(bm.random.randn(2))

    def update(self, data):
        out = bm.sum(self.weight * data)
        return out

#%% Create network and fake data
print('Creating network... ')
with bm.training_environment():
    model = Network()
    optimizer = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

print('Creating data... ')
train_data = np.concatenate([np.random.randn(100, 2) + np.array([[-1,-1]]),
                             np.random.randn(100, 2) + np.array([[ 1, 1]])], axis=0)  # [batch, 2]
train_label = bm.concatenate([bm.zeros(100, dtype=bm.int32), 
                              bm.ones(100, dtype=bm.int32)], axis=0)  # [batch]

#%% Training functions
def loss_fun(x_single, y_single):
    '''
    Inputs:
        x_single: [feature]
        y_single: scalar
    '''
    predict = model.update(x_single)  # scalar
    loss = bp.losses.binary_logistic_loss(predict, y_single)  # scalar
    acc = bm.mean(bm.int32(predict >= 0.0) == y_single)  # scalar
    return loss, acc

grad_f = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), has_aux=True, return_value=True)

def grad_fun(last_grad, x_y_single):
    '''
    Inputs:
        last_grad: PyTree of gradients of each trainable parameter.
        x_y_single: tuple of ([feature], scalar), a single training sample.
    '''
    x_single, y_single = x_y_single  # [feature], scalar
    grads, loss, acc = grad_f(x_single, y_single)  # PyTree of gradients, scalar, scalar
    new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array)  # accumulate gradients
    return new_grad, (loss, acc)

@bm.jit
def train(x_batch, y_batch):
    '''
    Inputs:
        x_batch: [batch, feature]
        y_batch: [batch]
    '''
    train_vars = model.train_vars().unique()

    # Gradient accumulation
    grads = jax.tree_map(bm.zeros_like, train_vars)
    grads, (losses, acces) = bm.scan(grad_fun, grads, (x_batch, y_batch))  # PyTree of gradients, [batch], [batch]
    optimizer.update(grads)

    loss = losses.mean()  # scalar
    acc = acces.mean()    # scalar
    return loss, acc

#%% Start training
print('Start training...')
train_epochs = 15
train_loss = bm.zeros(train_epochs, dtype=bm.float_)
train_acc = bm.zeros(train_epochs, dtype=bm.float_)

for e in range(train_epochs):
    train_loss[e], train_acc[e] = train(train_data, train_label)
    print("Epoch {}, train_loss={:.3f}, train_acc={:.2f}%".format(e, train_loss[e], train_acc[e]*100.0))

print('Done!')

On BrainPy 2.4.6.post5, the code above trains normally, but on BrainPy 2.5.0 it raises the following error:

UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was fun2scan at /home/xxx/miniconda3/envs/brainpy2.5/lib/python3.11/site-packages/brainpy/_src/math/object_transform/controls.py:929 traced for scan.
------------------------------
The leaked intermediate value was created on line /home/xxx/project/test_train_bug.py:62:53 (grad_fun.<locals>.<lambda>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_571020/323626455.py:1 (<module>)
/home/xxx/project/test_train_bug.py:90:34 (<module>)
/home/xxx/project/test_train_bug.py:76:29 (train)
/home/xxx/project/test_train_bug.py:62:15 (grad_fun)
/home/xxx/project/test_train_bug.py:62:53 (grad_fun.<locals>.<lambda>)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
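The message states the general rule: a traced value created inside a transformation (here, the body of a scan) may leave only through the function's return values. Below is a minimal pure-JAX sketch of that failure mode, offered as an analogy rather than as BrainPy's actual mechanism. The class `Box` is hypothetical: it stands in for any stateful wrapper (it is not BrainPy's `TrainVar`) and keeps a reference to a value created inside a `jax.lax.scan` body.

```python
import jax
import jax.numpy as jnp


class Box:
    """Hypothetical stateful wrapper; NOT BrainPy's TrainVar or its internals."""

    registry = []  # class-level list: global state that outlives any trace

    def __init__(self, value):
        self.value = value
        Box.registry.append(self)  # side effect: keeps a reference to `value`


def body(carry, x):
    new = carry + x
    Box(new)  # `new` is a tracer here; storing it globally lets it escape
    return new, None  # the carry itself is returned correctly


# The scan itself succeeds, because its outputs are returned explicitly.
total, _ = jax.lax.scan(body, jnp.zeros(2), jnp.ones((3, 2)))
print(total)  # [3. 3.]

# Using the stored tracer after the scan trace has ended raises the error.
caught = False
try:
    Box.registry[0].value + 1
except jax.errors.UnexpectedTracerError:
    caught = True
print("leak detected:", caught)
```

The computation is correct; only the reference that outlives the trace is a problem, which is why the error surfaces later than the line that created the value.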

Environment (BrainPy 2.5.0):

Environment (BrainPy 2.4.6.post5):

chaoming0625 commented 5 months ago

Thanks for the report!

The problem can be fixed by changing the line

new_grad = jax.tree_map(lambda x, y: bm.TrainVar(bm.add(x, y)), last_grad, grads, is_leaf=bm.is_bp_array)  # accumulate gradients

into

new_grad = jax.tree_map(bm.add, last_grad, grads)  # accumulate gradients
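The fixed line keeps the scan carry as a plain pytree of arrays: each step returns a new accumulator instead of wrapping the sum in a freshly created `bm.TrainVar`. Below is a minimal pure-JAX sketch of the same per-sample accumulation pattern. It assumes only `jax` (no BrainPy) and omits the optimizer step. The names `loss_and_acc` and `accumulate` are hypothetical, and the loss is the standard logistic (binary cross-entropy) loss on a single logit.

```python
import jax
import jax.numpy as jnp


def loss_and_acc(params, x, y):
    # Single-sample logistic regression: logit = w . x, label y in {0, 1}
    logit = jnp.sum(params["w"] * x)
    loss = jax.nn.softplus(logit) - y * logit  # BCE computed from a logit
    acc = ((logit >= 0.0) == (y == 1)).astype(jnp.float32)
    return loss, acc  # (value, aux)


@jax.jit
def accumulate(params, xs, ys):
    # has_aux=True: differentiate the loss only; accuracy is passed through
    grad_fn = jax.value_and_grad(loss_and_acc, has_aux=True)

    def step(grad_sum, xy):
        x, y = xy
        (loss, acc), grads = grad_fn(params, x, y)
        # Plain tree-wise addition: the carry keeps the same structure and
        # dtypes, and no new stateful object is created inside the scan body.
        new_sum = jax.tree_util.tree_map(jnp.add, grad_sum, grads)
        return new_sum, (loss, acc)

    init = jax.tree_util.tree_map(jnp.zeros_like, params)
    grad_sum, (losses, accs) = jax.lax.scan(step, init, (xs, ys))
    return grad_sum, losses.mean(), accs.mean()
```

Summing rather than averaging the gradients mirrors the original `grad_fun`. The sketch uses `jax.tree_util.tree_map` because `jax.tree_map` is deprecated in recent JAX releases.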

Please let me know whether the changes fix the error.

CloudyDory commented 5 months ago

Thank you very much for the reply; it fixes the error. Could you briefly explain why it happens?

chaoming0625 commented 5 months ago

The cause of this error is not intuitive: it involves how BrainPy traces variables internally. I would not encourage you to dig into it. :joy::joy: