Small change so that utils.py does not depend on wrappers.py.
Calls to `get_ravelled_pytree_constructor` must now explicitly pass the `*args` and `**kwargs` used to partition the parameters (usually by setting `is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)`).
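Why the caller now has to supply `is_leaf` can be shown with a minimal pure-Python sketch. It assumes plain dicts, lists, tuples and dataclasses in place of JAX pytrees, and it forwards only `is_leaf` rather than arbitrary partitioning arguments. `NonTrainable` and `get_ravelled_constructor_sketch` are hypothetical stand-ins, not the library's actual code. The utility never imports the wrapper. Without `is_leaf` it descends into the wrapper and treats the wrapped value as trainable. With the caller's `is_leaf` the wrapped value is kept fixed.

```python
import dataclasses
from typing import Any, Callable, List, Optional, Tuple


@dataclasses.dataclass
class NonTrainable:
    """Hypothetical stand-in for wrappers.NonTrainable: wraps a value to keep fixed."""
    value: Any


def get_ravelled_constructor_sketch(
    tree: Any, is_leaf: Optional[Callable[[Any], bool]] = None
) -> Tuple[Callable[[List[float]], Any], int]:
    """Hypothetical utility with no dependency on the wrapper class.

    Floats are trainable parameters; dicts, lists, tuples and dataclass
    instances are containers. Nodes for which the caller's is_leaf returns
    True are kept as-is. Returns (constructor, number_of_parameters).
    """

    def is_container_dataclass(node: Any) -> bool:
        return dataclasses.is_dataclass(node) and not isinstance(node, type)

    def count_params(node: Any) -> int:
        if is_leaf is not None and is_leaf(node):
            return 0  # caller-marked leaf: static, not ravelled
        if isinstance(node, dict):
            return sum(count_params(child) for child in node.values())
        if isinstance(node, (list, tuple)):
            return sum(count_params(child) for child in node)
        if is_container_dataclass(node):
            return sum(count_params(getattr(node, f.name)) for f in dataclasses.fields(node))
        return 1 if isinstance(node, float) else 0

    num_params = count_params(tree)

    def construct(flat: List[float]) -> Any:
        if len(flat) != num_params:
            raise ValueError(f"expected {num_params} parameters, got {len(flat)}")
        values = iter(flat)

        # Same traversal order as count_params, so values line up with leaves.
        def rebuild(node: Any) -> Any:
            if is_leaf is not None and is_leaf(node):
                return node
            if isinstance(node, dict):
                return {key: rebuild(child) for key, child in node.items()}
            if isinstance(node, (list, tuple)):
                return type(node)(rebuild(child) for child in node)
            if is_container_dataclass(node):
                updates = {f.name: rebuild(getattr(node, f.name)) for f in dataclasses.fields(node)}
                return dataclasses.replace(node, **updates)
            return next(values) if isinstance(node, float) else node

        return rebuild(tree)

    return construct, num_params


params = {"scale": 2.0, "shift": NonTrainable(0.5)}

# Caller passes is_leaf, so the wrapped value stays fixed: one parameter.
construct, n = get_ravelled_constructor_sketch(
    params, is_leaf=lambda leaf: isinstance(leaf, NonTrainable)
)
print(n)                 # 1
print(construct([3.0]))  # {'scale': 3.0, 'shift': NonTrainable(value=0.5)}

# Without is_leaf the wrapper is just another container: two parameters.
_, n_all = get_ravelled_constructor_sketch(params)
print(n_all)             # 2
```

Passing `is_leaf` from the call site keeps the knowledge of which nodes are non-trainable with the caller, which is what removes the `utils.py` to `wrappers.py` dependency.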