RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Replace all `params_env_name` with `FrozenDict` #14

Closed by RobertTLange 3 years ago

RobertTLange commented 3 years ago

Some params dictionaries specify the shapes of observations. Hence, when jitting, we need to mark them as static via static_argnums. That in turn is only possible if the dictionary is hashable, which a plain mutable dict is not, so it has to be immutable. I propose porting the flax FrozenDict and providing a helper function called update_env_params(params, x_name, x_value), which unfreezes, changes, and re-freezes the dictionary.
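To make the proposal concrete, here is a minimal sketch of what such a helper could look like. It uses only the standard library: the `FrozenDict` class below is a simplified stand-in written for illustration, not the flax implementation, and the parameter names `obs_shape` and `max_steps` are hypothetical. Only the `update_env_params(params, x_name, x_value)` signature comes from the proposal above.

```python
from collections.abc import Mapping


class FrozenDict(Mapping):
    """Minimal immutable, hashable mapping (illustrative stand-in, not flax's class)."""

    def __init__(self, *args, **kwargs):
        self._data = dict(*args, **kwargs)
        self._hash = None

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

    def __hash__(self):
        # Cache the hash. Values must themselves be hashable
        # (e.g. shapes stored as tuples, not lists).
        if self._hash is None:
            self._hash = hash(frozenset(self._data.items()))
        return self._hash

    def __repr__(self):
        return f"FrozenDict({self._data!r})"


def update_env_params(params, x_name, x_value):
    """Return a new FrozenDict with one entry changed; the input is untouched."""
    unfrozen = dict(params)      # unfreeze into a mutable copy
    unfrozen[x_name] = x_value   # change the requested entry
    return FrozenDict(unfrozen)  # freeze again


# Hypothetical parameter names, for illustration only.
params = FrozenDict({"obs_shape": (4,), "max_steps": 200})
new_params = update_env_params(params, "max_steps", 500)
```

Because the mapping defines `__hash__` and inherits `__eq__` from `Mapping`, an instance should be usable as a static argument to `jax.jit`, which requires static arguments to be hashable. Note that each distinct static value triggers a fresh compilation, so frequently changing parameters this way would come at a retracing cost.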

In order to reduce dependencies, it may make sense to simply copy the file and use the same Apache License.

https://github.com/google/flax/blob/ac0f57419f32c9924e094e7e0dc82a15be228b5d/flax/core/frozen_dict.py

Go through all envs and update the parameter dictionaries.

RobertTLange commented 3 years ago

Addressed in #17.