adaptive-intelligent-robotics / QDax

Accelerated Quality-Diversity
https://qdax.readthedocs.io/en/latest/
MIT License

New-style JAX RNG keys cannot be converted directly to NumPy random state seeds #176

Open miltonllera opened 6 months ago

miltonllera commented 6 months ago

JAX recently started moving towards new-style (typed) RNG keys, which means replacing calls to jax.random.PRNGKey with calls to jax.random.key. In addition, when a new-style key is used to seed numpy.random.RandomState (as in compute_cvt_centroids), it must first be passed through jax.random.key_data; otherwise the conversion fails with a type error.
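A minimal sketch of the conversion, assuming a JAX release recent enough to provide jax.random.key and jax.dtypes.prng_key. The helper name to_numpy_random_state is hypothetical and is not part of QDax's API; it accepts either a legacy uint32 key (from jax.random.PRNGKey) or a new-style typed key (from jax.random.key) and only calls jax.random.key_data for the latter.

```python
import jax
import numpy as np


def to_numpy_random_state(key: jax.Array) -> np.random.RandomState:
    """Hypothetical helper: build a NumPy RandomState from a JAX PRNG key.

    Works with both legacy raw keys (uint32 arrays from jax.random.PRNGKey)
    and new-style typed keys (from jax.random.key).
    """
    if jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key):
        # Typed keys cannot be handed to NumPy directly (np.asarray raises),
        # so extract the underlying uint32 words first.
        key = jax.random.key_data(key)
    # RandomState accepts a 1-D array of integers in [0, 2**32) as a seed.
    seed = np.asarray(key, dtype=np.uint32).ravel()
    return np.random.RandomState(seed)


# Migration: jax.random.PRNGKey(seed) -> jax.random.key(seed)
legacy_key = jax.random.PRNGKey(42)  # raw uint32 array of shape (2,)
typed_key = jax.random.key(42)       # typed key with a PRNG key dtype

rs_legacy = to_numpy_random_state(legacy_key)
rs_typed = to_numpy_random_state(typed_key)

# Assuming the default threefry implementation, both keys carry the same
# bits, so the two RandomState objects produce identical draws.
print(np.array_equal(rs_legacy.uniform(size=3), rs_typed.uniform(size=3)))  # True
```

Branching on the dtype, rather than always calling jax.random.key_data, keeps the helper explicit about which key style it received, so existing code that still uses jax.random.PRNGKey keeps working during a migration.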

I am happy to write a pull request for this.

Aneoshun commented 6 months ago

Hi Milton,

Sounds great. Please write the PR for this; we will be happy to review and accept it.

miltonllera commented 2 months ago

Sorry, I've been slow on this.

I've submitted a PR, but I ran into some issues (as the failed tests show). The problem is that the version of Haiku pinned in requirements.txt (0.0.10) doesn't support new-style RNG keys. The latest version (0.0.13) does, as I have tested on my own machine. Updating the requirements should therefore allow the examples and tests to remain unchanged.