danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Add non_trainable #165

Closed: danielward27 closed this 2 months ago

danielward27 commented 2 months ago

Adds a likely better way to mark parameters as non-trainable: at the array level rather than the tree level. Marking at the tree level gives a simpler pytree definition, nicer printing, etc., but marking at the array level is more robust against missing-attribute errors, e.g. when manipulating pytrees.
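A minimal sketch of the array-level idea, assuming gradients are blocked with `jax.lax.stop_gradient` when a wrapped array is read. The PR title names `non_trainable`, but the `NonTrainable` class and the `non_trainable` and `unwrap` helpers below are hypothetical illustrations, not flowjax's actual code. The wrapped array remains a child of the pytree, so the marker travels with the array through tree utilities such as `jax.tree_util.tree_map`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class NonTrainable:
    """Hypothetical wrapper marking a single array as non-trainable.

    The array is registered as a pytree child, so tree utilities still
    see it; gradients are only blocked when the value is read via `unwrap`.
    """

    def __init__(self, array):
        self.array = array

    def tree_flatten(self):
        # One child (the array), no static auxiliary data.
        return (self.array,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def non_trainable(tree):
    """Hypothetical helper: wrap every leaf of `tree` in NonTrainable."""
    return jax.tree_util.tree_map(NonTrainable, tree)


def unwrap(tree):
    """Replace NonTrainable nodes with stop_gradient-ed arrays."""

    def _unwrap(node):
        if isinstance(node, NonTrainable):
            return jax.lax.stop_gradient(node.array)
        return node

    return jax.tree_util.tree_map(
        _unwrap, tree, is_leaf=lambda x: isinstance(x, NonTrainable)
    )


# "scale" is marked non-trainable at the array level.
params = {"loc": jnp.array(1.0), "scale": non_trainable(jnp.array(2.0))}


def loss(params):
    p = unwrap(params)
    return (p["loc"] * p["scale"]) ** 2


# The gradient w.r.t. "scale" is zero because of stop_gradient.
grads = jax.grad(loss)(params)

# Tree manipulations keep the marker attached to the array.
doubled = jax.tree_util.tree_map(lambda x: 2 * x, params)
```

Here `grads["loc"]` is 8.0 while `grads["scale"]` is a `NonTrainable` holding 0.0, and `doubled["scale"]` is still a `NonTrainable` wrapper. No separate filter structure has to be kept in sync with the tree.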