Some `params` dictionaries specify the shapes of observations. Hence, when jitting, we need to mark them as `static_argnums`. That in turn is only possible if the dictionary is hashable, which in practice means immutable. I propose porting the flax `FrozenDict` and providing a helper function called `update_env_params(params, x_name, x_value)`, which unfreezes the dictionary, changes the entry and freezes it again.
To reduce dependencies, it may make sense to simply copy the file and keep its Apache License.
https://github.com/google/flax/blob/ac0f57419f32c9924e094e7e0dc82a15be228b5d/flax/core/frozen_dict.py
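A minimal pure-Python sketch of the proposal, assuming that every parameter value is itself hashable (ints, floats, tuples of shape dimensions). The `FrozenDict`, `freeze`, `unfreeze` and `update_env_params` below are hypothetical stand-ins written for illustration, not a copy of the flax file, and the parameter names `max_steps` and `obs_shape` are made up. The hash is computed from the contents, so two equal parameter sets should map to the same compilation cache entry when passed as a static argument to `jax.jit`.

```python
from collections.abc import Mapping
from typing import Any, Iterator


class FrozenDict(Mapping):
    """Minimal immutable, hashable mapping (hypothetical stand-in for flax's FrozenDict)."""

    __slots__ = ("_dict", "_hash")

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Recursively freeze nested dicts so the whole structure is hashable.
        items = dict(*args, **kwargs)
        self._dict = {k: freeze(v) for k, v in items.items()}
        self._hash = None

    def __getitem__(self, key: str) -> Any:
        return self._dict[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._dict)

    def __len__(self) -> int:
        return len(self._dict)

    def __hash__(self) -> int:
        # Hash by contents (assumes all values are hashable), cached after first use.
        if self._hash is None:
            self._hash = hash(frozenset(self._dict.items()))
        return self._hash

    def __repr__(self) -> str:
        return f"FrozenDict({self._dict!r})"


def freeze(x: Any) -> Any:
    """Turn a (possibly nested) dict into a FrozenDict; leave other values unchanged."""
    return FrozenDict(x) if isinstance(x, dict) else x


def unfreeze(x: Any) -> Any:
    """Turn a (possibly nested) FrozenDict back into plain, mutable dicts."""
    if isinstance(x, FrozenDict):
        return {k: unfreeze(v) for k, v in x.items()}
    return x


def update_env_params(params: FrozenDict, x_name: str, x_value: Any) -> FrozenDict:
    """Hypothetical helper: unfreeze, change one entry, freeze again.

    Returns a new FrozenDict; the original `params` object is left untouched.
    """
    mutable = unfreeze(params)
    mutable[x_name] = x_value
    return freeze(mutable)


# Usage with made-up parameter names.
params = freeze({"max_steps": 500, "obs_shape": (4,)})
new_params = update_env_params(params, "max_steps", 1000)
print(params["max_steps"], new_params["max_steps"])  # 500 1000
```

Because the helper returns a new object instead of mutating in place, the old `params` stays valid. When such a mapping is passed as a static argument, e.g. `jax.jit(step, static_argnums=(2,))` for a hypothetical `step(key, state, params)`, a changed value should trigger a recompilation, while equal contents reuse the cached one.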
Go through all envs and update the parameter dictionaries.