Status: Open. corentinlger opened this issue 3 months ago.
Another issue that occurs (after fixing this one) is:
AttributeError: DynamicJaxprTracer has no attribute features
It is raised in the initialize_carry function of Flax's GRUCell, when it tries to access the features attribute of what turns out to be a traced object.
Yes, I think it's the second point I mentioned in the issue.
You could try this fixed version of the file (it worked a month ago): https://github.com/corentinlger/purejaxrl/blob/fix_ppo_rnn/purejaxrl/ppo_rnn.py
Hello, I wanted to use ppo_rnn.py and encountered an error when running the algorithm. It concerned the input arguments of the initialize_carry function used to create the carry for the GRUCell. I think this is due to an update of the Flax RNN API:
1. the arguments of initialize_carry are now (rng, input_shape) instead of (rng, batch_size, hidden_size);
2. the number of features is now passed as an argument of RNNCells (GRUCell here).
Code to reproduce the error:
Error message:
If this is indeed the cause, would you like me to open a PR to fix it?