brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

multiple GPU devices simulation and training of one dynamic system in brainpy #641

Open Dr-Chen-Xiaoyu opened 4 months ago

Dr-Chen-Xiaoyu commented 4 months ago

Hi, Chaoming:

I am trying to simulate and train a dynamical system (a custom RNN built on brainpy, https://github.com/Dr-Chen-Xiaoyu/DecoModel) with a very large number of dimensions and time steps. Its memory usage exceeds what a single GPU device can provide.

I believe this could be solved by running brainpy on multiple GPU devices with its own sharding method, similar to jax's sharding or pytorch's torch.nn.DataParallel. A simplified RNN training example is provided below; to reproduce the problem, increase the RNN dimension (maybe >1000) and the size of the input and output tensors (maybe >1000^3). Perhaps you could modify this code to use brainpy's sharding and add it to brainpy's tutorials, if this is a common need among users.
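For a rough sense of scale, the sizes mentioned above can be turned into memory estimates. This is only a sketch: it assumes float32 (4 bytes per element) and counts the raw tensors alone, not gradients, optimizer state, or the intermediates kept for backpropagation. The helper name `tensor_gib` is illustrative and not part of brainpy.

```python
import math

def tensor_gib(shape, bytes_per_element=4):
    """Return the size of a dense tensor in GiB (float32 assumed by default)."""
    return math.prod(shape) * bytes_per_element / 2**30

# One input or output tensor with 1000^3 elements.
io_gib = tensor_gib((1000, 1000, 1000))
# A recurrent weight matrix for 1000 hidden units.
wrec_gib = tensor_gib((1000, 1000))

print(f"1000^3 float32 tensor: {io_gib:.2f} GiB")
print(f"1000x1000 float32 matrix: {wrec_gib:.4f} GiB")
```

Under these assumptions the input and output tensors are far larger than the weights themselves.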

best, Xiaoyu

The example code:

# %%
import os,jax
import numpy as np
import matplotlib.pyplot as plt

import brainpy as bp
import brainpy.math as bm
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print(jax.local_devices())
#bp version: 2.4.6.post5
#[cuda(id=0), cuda(id=1)]

# %%
class RNN(bp.DynamicalSystemNS):
    def __init__(self, num_in, num_hid, num_out, batch_size=1):
        super(RNN, self).__init__()

        bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

        # define parameters
        self.num_in  = num_in
        self.num_hid = num_hid
        self.num_out = num_out

        # define variables
        self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

        # define weights
        self.win  = bm.TrainVar(bm.random.normal(0., 1., size=(num_in,  num_hid)))
        self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
        self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

    def reset_state(self, batch_size):  # this function defines how to reset the model states
        self.state.value = bm.zeros((batch_size, self.num_hid))

    def update(self, x):  # this function defines how the model updates its state and produces its output
        self.state.value = bm.tanh( bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec) )
        return bm.matmul(self.state, self.wout)

# initialize model
bm.random.seed(123)
dim_in =1
dim_hid=10
dim_out=1
batch_size=1
model = RNN(dim_in, dim_hid, dim_out , batch_size)

# %%
# generate some data
Nsample = 500
X_train = bm.random.normal(0.,1., size=(batch_size ,Nsample,dim_in)) #(Batch,Time,dim)
Y_train = bm.random.normal(10.,1., size=(batch_size, Nsample,dim_out))

def plot_model_predict(model,X_train,Y_train):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    Y_model = runner.run(inputs=X_train)

    plt.plot(X_train[0,:,:])
    plt.plot(Y_train[0,:,:])
    plt.plot(Y_model[0,:,:])
    plt.show()
plot_model_predict(model,X_train,Y_train)

# %%
# training
def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss

grad_fun = bm.grad(loss_fun,grad_vars=model.train_vars().unique(),return_value=True)

opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

@bm.jit
def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

losses=[]
for _ in range(1000):
    losses.append(train(X_train,Y_train))

plt.plot(losses);plt.show()
plot_model_predict(model,X_train,Y_train)

Dr-Chen-Xiaoyu commented 4 months ago

I think I may have found a way to shard a bm.array, based on JAX's tutorial https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html :


# %%
import jax
import jax.numpy as jnp

import os
import numpy as np

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0,1" # specify which GPU(s) to be used
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"
import brainpy as bp
import brainpy.math as bm
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
print('bp version:', bp.__version__)

from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.sharding import PartitionSpec as P

# %%
def get_sharding_details(sharded_data):
    # We can get detailed information for each shard
    print("="*75)
    for i, shard in enumerate(sharded_data.global_shards):
        print(f"Shard no: {i:>5}")
        print(f"Device: {str(shard.device):>32}")
        print(f"Data shape: {str(shard.data.shape):>8}")
        print(f"Data slices: {str(shard.index):>22}")
        print("="*75)

# %%
devices = mesh_utils.create_device_mesh((len(jax.local_devices()),))
print(f"Device Array: {devices}")

# Create a mesh from the device array
mesh = Mesh(devices, axis_names=("ax",))

# Define sharding with a partition spec
sharding = NamedSharding(mesh, P("ax"))

print(mesh)

# %%
a = jnp.ones((1000,1000,3))
get_sharding_details(a)

print("\nafter sharding:\n")

# Shard the data
b = jax.device_put(a, sharding)
get_sharding_details(b)

# %%
c = bm.ones((1000,1000,3))
get_sharding_details(c.value)

print("\nafter sharding:\n")

# Shard the data
d = bm.sharding.partition_by_sharding(c, sharding)
get_sharding_details(d.value)

Maybe it would be enough to shard the input and output bm.array tensors along the batch axis and then let the computation run automatically across multiple GPUs? Just a thought 😊
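The batch-axis idea can be checked on paper with a NumPy-only simulation. This is a sketch, not brainpy or JAX code: the two "devices" are just list entries, the weights are assumed to be replicated on each of them, and all shapes are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_in, num_hid = 8, 3, 5
x = rng.normal(size=(batch, num_in))    # inputs, batch on axis 0
w = rng.normal(size=(num_in, num_hid))  # weights, replicated on every device

# Reference: the whole batch processed in one place.
full = np.tanh(x @ w)

# Split the batch into two shards, one per simulated device.
shards = np.split(x, 2, axis=0)
per_device = [np.tanh(shard @ w) for shard in shards]

# Gathering the per-device outputs along the batch axis recovers the full result.
gathered = np.concatenate(per_device, axis=0)
print(np.allclose(full, gathered))
```

Each shard needs the complete weight matrix, so this layout divides the activations across devices but not the parameters.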

Dr-Chen-Xiaoyu commented 4 months ago

The printed output for the array before and after sharding looks like this:

===========================================================================
Shard no:     0
Device:                           cuda:0
Data shape: (1000, 1000, 3)
Data slices: (slice(None, None, None), slice(None, None, None), slice(None, None, None))
===========================================================================

after sharding:

===========================================================================
Shard no:     0
Device:                           cuda:0
Data shape: (500, 1000, 3)
Data slices: (slice(0, 500, None), slice(None, None, None), slice(None, None, None))
===========================================================================
Shard no:     1
Device:                           cuda:1
Data shape: (500, 1000, 3)
Data slices: (slice(500, 1000, None), slice(None, None, None), slice(None, None, None))
===========================================================================
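The slices in this output follow a simple pattern: the first axis is cut into equal contiguous blocks, one per device, and the other axes are kept whole. The helper below reproduces that pattern in plain Python. It is an illustrative sketch (the name `shard_slices` is hypothetical) and makes no claim about how JAX computes the indices internally.

```python
def shard_slices(shape, num_devices, axis=0):
    """Return one tuple of slices per device, splitting `axis` evenly."""
    if shape[axis] % num_devices != 0:
        raise ValueError("axis length must be divisible by the number of devices")
    step = shape[axis] // num_devices
    result = []
    for d in range(num_devices):
        index = [slice(None)] * len(shape)             # keep every axis whole...
        index[axis] = slice(d * step, (d + 1) * step)  # ...except the sharded one
        result.append(tuple(index))
    return result

for d, index in enumerate(shard_slices((1000, 1000, 3), 2)):
    print(d, index)
```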

chaoming0625 commented 4 months ago

Thanks for the question. Sorry for the slow response. I will check it later.

Dr-Chen-Xiaoyu commented 2 months ago

Hi, chaoming @chaoming0625

Maybe this issue is a bit hard and needs too much engineering work to solve. 🫡

I have an idea for a quick and cheap solution. Following #663, if any built-in or customized brainpy dynamical system class could be automatically converted into a Flax RNN cell using bp.dnn.ToFlaxRNNCell(), then we could simply do multi-GPU parallel training with Flax (https://flax.readthedocs.io/en/latest/guides/parallel_training/index.html). 🤖

best, Xiaoyu Chen

chaoming0625 commented 2 months ago

Yes, the idea is simple. I will give you the solution soon.

chaoming0625 commented 2 months ago

Here is my example of using multiple GPUs. The key lines are marked with the comment [KEY].


import os
import jax

import brainpy as bp
import brainpy.math as bm

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # specify which GPU(s) to be used
bm.disable_gpu_memory_preallocation()
bm.set_platform('gpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print(jax.local_devices())

# bp version: 2.4.6.post5
# [cuda(id=0), cuda(id=1)]

# %%
class RNN(bp.DynamicalSystemNS):
  def __init__(self, num_in, num_hid, num_out, batch_size=1):
    super(RNN, self).__init__()

    bp.check.is_subclass(self.mode, (bm.TrainingMode, bm.BatchingMode))

    # define parameters
    self.num_in = num_in
    self.num_hid = num_hid
    self.num_out = num_out

    # define variables [KEY]
    self.state = bp.init.variable_(bm.zeros, (num_hid,), batch_size, axis_names=['hidden'])
    # self.state = bm.Variable(bm.zeros((batch_size, num_hid)), batch_axis=0)

    # define weights [KEY]
    self.win = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_in, num_hid), axis_names=[None, 'hidden']))
    self.wrec = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_hid), axis_names=[None, 'hidden']))
    self.wout = bm.TrainVar(bp.init.variable_(bp.init.Normal(), (num_hid, num_out), axis_names=['hidden', None]))
    # self.win = bm.TrainVar(bm.random.normal(0., 1., size=(num_in, num_hid)))
    # self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
    # self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

  def reset_state(self, batch_size):  # this function defines how to reset the model states
    self.state.value = bp.init.variable_(bm.zeros, (self.num_hid,), batch_size, axis_names=['hidden'])

  def update(self, x):  # this function defines how the model updates its state and produces its output
    self.state.value = bm.tanh(bm.matmul(x, self.win) + bm.matmul(self.state, self.wrec))
    return bm.matmul(self.state, self.wout)

with bm.sharding.device_mesh(jax.devices(), ['hidden']):  # [KEY]
  # initialize model
  bm.random.seed(123)
  dim_in = 1
  dim_hid = 10
  dim_out = 1
  batch_size = 1
  model = RNN(dim_in, dim_hid, dim_out, batch_size)

  # %%
  # generate some data
  Nsample = 500
  X_train = bm.random.normal(0., 1., size=(batch_size, Nsample, dim_in))  # (Batch,Time,dim)
  Y_train = bm.random.normal(10., 1., size=(batch_size, Nsample, dim_out))

  # training
  def loss_fun(inputs, targets):
    runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
    predicts = runner.predict(inputs)
    loss = bp.losses.mean_squared_error(predicts, targets)
    return loss

  grad_fun = bm.grad(loss_fun, grad_vars=model.train_vars().unique(), return_value=True)

  opt = bp.optim.Adam(lr=1e-1, train_vars=model.train_vars().unique())

  @bm.jit
  def train(xs, ys):
    grads, loss = grad_fun(xs, ys)
    opt.update(grads)
    return loss

  losses = []
  for _ in range(1000):
    losses.append(train(X_train, Y_train))

chaoming0625 commented 2 months ago

The concept is very simple.

  1. Initialize a context manager to set up a device mesh. Here,

with bm.sharding.device_mesh(devices, ['hidden']):
   ...

means that the hidden dimension will be partitioned across the given devices.

Note that the device array must have one dimension per axis name. For example, if you want to partition the model over a two-dimensional device grid by input and hidden, set up the context as follows (this assumes four devices and import numpy as np):

with bm.sharding.device_mesh(np.asarray(jax.devices()).reshape(2, 2), ['input', 'hidden']):
   ...

  2. Initialize the variables and weights with brainpy.init.variable_(...., axis_names=['input', 'hidden']). The data will be automatically partitioned across the devices if a given axis name matches a device mesh axis.

  3. Use brainpy.math.jit. This is the key to the parallelization. All functions should have a jit decorator; otherwise, the model will not be parallelized according to the settings.
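To see why partitioning along the hidden axis is mathematically safe, the RNN update from the example can be simulated with NumPy alone. This is a sketch under assumptions: the two "devices" are list entries, the shapes are made up, and the code only shows that the block-wise computation equals the unsharded one. It says nothing about how the compiler actually places data or schedules communication.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_in, num_hid, num_out, n_dev = 2, 3, 8, 4, 2
x = rng.normal(size=(batch, num_in))
state = rng.normal(size=(batch, num_hid))
win = rng.normal(size=(num_in, num_hid))
wrec = rng.normal(size=(num_hid, num_hid))
wout = rng.normal(size=(num_hid, num_out))

# Unsharded reference for one update step.
ref_state = np.tanh(x @ win + state @ wrec)
ref_out = ref_state @ wout

# Column blocks of win/wrec ([None, 'hidden']) and row blocks of wout (['hidden', None]).
win_blocks = np.split(win, n_dev, axis=1)
wrec_blocks = np.split(wrec, n_dev, axis=1)
wout_blocks = np.split(wout, n_dev, axis=0)

# Each simulated device produces its own slice of the new hidden state.
state_blocks = [np.tanh(x @ wi + state @ wr) for wi, wr in zip(win_blocks, wrec_blocks)]
# Each device multiplies its state slice by its rows of wout; summing the
# partial products plays the role of an all-reduce.
out = sum(s @ wo for s, wo in zip(state_blocks, wout_blocks))

print(np.allclose(np.concatenate(state_blocks, axis=1), ref_state))
print(np.allclose(out, ref_out))
```

Note that every block still reads the full previous state, so some exchange of state between devices is needed at each time step.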

chaoming0625 commented 2 months ago

Please tell me whether the above code works.

Please also see the TPU multi-device partition example of the COBA-HH network model.

chaoming0625 commented 2 months ago

By the way, I apologize for the very late response!

Dr-Chen-Xiaoyu commented 2 months ago

It works! Thanks so much!🫰

For the model without sharding: [image]

With sharding, the memory is shared across the two GPU cards and training is 2x faster🫡: [image]

chaoming0625 commented 2 months ago

Thanks for the feedback!

Dr-Chen-Xiaoyu commented 2 months ago

One more question about the details. It seems that you partition the model (the hidden states of this RNN) across two GPUs. Why not partition along the batch axis? That seems more natural for users.

chaoming0625 commented 2 months ago

This is a good idea. However, if the batch size is what prevents the model from being trained on one GPU, we can decrease the batch size rather than partition it across multiple devices. A harder situation is when the model is too big to fit on one device. In such cases, we can partition the model across multiple devices, for example when simulating a very large-scale SNN model (where there is usually no batch dimension).

chaoming0625 commented 2 months ago

Partitioning the hidden states and their interaction matrix is a simple model parallelization method.
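The trade-off between the two layouts can be put into numbers with a back-of-the-envelope estimate. The sketch below counts only the state and the recurrent matrix in float32 and assumes that batch sharding replicates the weights while hidden sharding splits their columns. The function name and the sizes are hypothetical.

```python
def per_device_mib(batch, num_hid, n_dev, strategy, bytes_per_element=4):
    """Rough per-device memory (MiB) for the state and recurrent matrix only."""
    if strategy == "batch":
        state = (batch // n_dev) * num_hid   # state split by batch
        wrec = num_hid * num_hid             # weights replicated on each device
    elif strategy == "hidden":
        state = batch * (num_hid // n_dev)   # state split by hidden units
        wrec = num_hid * (num_hid // n_dev)  # weight columns split
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return (state + wrec) * bytes_per_element / 2**20

for strategy in ("batch", "hidden"):
    mib = per_device_mib(batch=2, num_hid=20000, n_dev=2, strategy=strategy)
    print(f"{strategy}: {mib:.1f} MiB per device")
```

With a large hidden size and a small batch, the weight matrix dominates, so only the hidden-axis layout reduces the per-device footprint in this estimate.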

Dr-Chen-Xiaoyu commented 2 months ago

Okay, I see.

By the way, I found that in the model definition, changing only the line that defines the model state variable is enough for parallelization. There is no need to change the weight TrainVars to use axis_names=['input', 'hidden'].

# define variables
self.state = bp.init.variable(bm.zeros, batch_size, num_hid,  axis_names=['hidden'], batch_axis_name=['batch']) # <<< key point

# define weights
self.win  = bm.TrainVar(bm.random.normal(0., 1., size=(num_in,  num_hid))) # no change needed
self.wrec = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_hid)))
self.wout = bm.TrainVar(bm.random.normal(0., 1., size=(num_hid, num_out)))

Thanks again for the help👍👍👍