luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0
738 stars, 62 forks

[Bug] Error in initialize carry function for ppo_rnn #29

Open: corentinlger opened this issue 3 months ago

corentinlger commented 3 months ago

Hello, I wanted to use ppo_rnn.py and encountered an error when running the algorithm. It concerns the input arguments passed to the initialize_carry function that creates the carry for the GRUCell.

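I think this is due to a change in the Flax RNN API: GRUCell.initialize_carry is now an instance method with the signature initialize_carry(self, rng, input_shape), as the traceback below shows.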

Code to reproduce the error:

import jax 
from purejaxrl.ppo_rnn import make_train

config = {
    "LR": 2.5e-4,
    "NUM_ENVS": 4,
    "NUM_STEPS": 128,
    "TOTAL_TIMESTEPS": 5e5,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 4,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.01,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "CartPole-v1",
    "ANNEAL_LR": True,
}

rng = jax.random.PRNGKey(42)
train_jit = jax.jit(make_train(config))
out = train_jit(rng)

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 24
     22 rng = jax.random.PRNGKey(42)
     23 train_jit = jax.jit(make_train(config))
---> 24 out = train_jit(rng)

    [... skipping hidden 11 frame]

File ~/Desktop/code/purejaxrl/purejaxrl/ppo_rnn.py:121, in make_train.<locals>.train(rng)
    114 rng, _rng = jax.random.split(rng)
    115 init_x = (
    116     jnp.zeros(
    117         (1, config["NUM_ENVS"], *env.observation_space(env_params).shape)
    118     ),
    119     jnp.zeros((1, config["NUM_ENVS"])),
    120 )
--> 121 init_hstate = ScannedRNN.initialize_carry(config["NUM_ENVS"], 128)
    122 network_params = network.init(_rng, init_hstate, init_x)
    123 if config["ANNEAL_LR"]:

File ~/Desktop/code/purejaxrl/purejaxrl/ppo_rnn.py:41, in ScannedRNN.initialize_carry(batch_size, hidden_size)
     38 @staticmethod
     39 def initialize_carry(batch_size, hidden_size):
     40     # Use a dummy key since the default state init fn is just zeros.
---> 41     return nn.GRUCell.initialize_carry(
     42         jax.random.PRNGKey(0), (batch_size,), hidden_size
     43     )

File ~/Desktop/code/purejaxrl/venv/lib/python3.10/site-packages/flax/linen/recurrent.py:614, in GRUCell.initialize_carry(self, rng, input_shape)
    603 @nowrap
    604 def initialize_carry(self, rng: PRNGKey, input_shape: Tuple[int, ...]):
    605  """Initialize the RNN cell carry.
    606 
    607  Args:
   (...)
    612    An initialized carry for the given RNN cell.
    613  """
--> 614   batch_dims = input_shape[:-1]
    615   mem_shape = batch_dims + (self.features,)
    616   return self.carry_init(rng, mem_shape, self.param_dtype)

TypeError: 'int' object is not subscriptable

If this is indeed the cause, would you like me to open a PR to fix it?

smokbel commented 1 month ago

Another issue that occurs (after fixing this one) is:

AttributeError: DynamicJaxprTracer has no attribute features

It is raised inside Flax's GRUCell.initialize_carry, when the method tries to read the cell's features attribute on what is actually a traced value.

corentinlger commented 1 month ago

Yes, I think this is the second point I mentioned in the issue.

You could try this fixed version of the file (it worked a month ago): https://github.com/corentinlger/purejaxrl/blob/fix_ppo_rnn/purejaxrl/ppo_rnn.py