danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Reduce coupling #189

Closed danielward27 closed 2 weeks ago

danielward27 commented 2 weeks ago

Small change so that `utils.py` no longer depends on `wrappers.py`.

Calls to `get_ravelled_pytree_constructor` must now explicitly pass the `*args` and `**kwargs` used for partitioning the parameters (usually `is_leaf=lambda leaf: isinstance(leaf, wrappers.NonTrainable)`).