brainpy / BrainPy

Brain Dynamics Programming in Python
https://brainpy.readthedocs.io/
GNU General Public License v3.0

bp.dnn.ToFlaxRNNCell is not working #663

Closed: Dr-Chen-Xiaoyu closed this issue 2 months ago

Dr-Chen-Xiaoyu commented 2 months ago

Hi, chaoming,

I am trying to use bp.dnn.ToFlaxRNNCell(), but it raises an error. I suspect this is caused by changes in newer versions of Flax, or maybe I am misusing the function?

Best, Xiaoyu Chen

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

import brainpy as bp
import brainpy.math as bm
bm.set_platform('cpu')
bm.set_mode(bm.training_mode)

print('bp version:', bp.__version__)
print('jax version:', jax.__version__)
print('flax version:', flax.__version__)
bp version: 2.6.0
jax version: 0.4.26
flax version: 0.4.26   (printed from jax.__version__ in the original run; the actual Flax version was not recorded)
cell = bp.dnn.ToFlaxRNNCell(bp.dyn.RNNCell(num_in=1, num_out=1,))

class myRNN(nn.Module):
    @nn.compact
    def __call__(self, x): # x:(batch, time, features)
        x = nn.RNN(cell)(x)  # Use nn.RNN to unfold the recurrent cell
        return x

model = myRNN()
model.init(jax.random.PRNGKey(0), jnp.ones([1,10,1])) # batch,time,feature
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/data/xyc/codes/Tests/BrainPy/hessian.ipynb Cell 27 line 1
      7         return x
      9 model = myRNN()
---> 10 model.init(jax.random.PRNGKey(0), jnp.ones([1,10,1])) # batch,time,feature

    [... skipping hidden 9 frame]

/data/xyc/codes/Tests/BrainPy/hessian.ipynb Cell 27 line 6
      4 @nn.compact
      5 def __call__(self, x): # x:(batch, time, features)
----> 6     x = nn.RNN(cell)(x)  # Use nn.RNN to unfold the recurrent cell
      7     return x

    [... skipping hidden 2 frame]

File ~/anaconda/envs/env_bp_cpu/lib/python3.11/site-packages/flax/linen/recurrent.py:1066, in RNN.__call__(self, inputs, initial_carry, init_key, seq_lengths, return_carry, time_major, reverse, keep_order)
   1061   keep_order = self.keep_order
   1063 # Infer the number of batch dimensions from the input shape.
   1064 # Cells like ConvLSTM have additional spatial dimensions.
   1065 time_axis = (
-> 1066   0 if time_major else inputs.ndim - (self.cell.num_feature_axes + 1)
   1067 )
   1069 # make time_axis positive
   1070 if time_axis < 0:

    [... skipping hidden 1 frame]

File ~/anaconda/envs/env_bp_cpu/lib/python3.11/site-packages/flax/linen/recurrent.py:84, in RNNCellBase.num_feature_axes(self)
     81 @property
     82 def num_feature_axes(self) -> int:
     83   """Returns the number of feature axes of the RNN cell."""
---> 84   raise NotImplementedError

NotImplementedError: 
chaoming0625 commented 2 months ago

Yes, this is a version-compatibility issue. The Flax RNN API has changed.

chaoming0625 commented 2 months ago

Sorry, I am busy with other things at the moment. I may be able to provide a fix this weekend.

chaoming0625 commented 2 months ago

I will also give you a solution for parallelization this weekend. Sorry for the late response.

chaoming0625 commented 2 months ago

See #665