cgarciae / treex

A Pytree Module system for Deep Learning in JAX
https://cgarciae.github.io/treex/
MIT License

Add `trainable_parameters` for extracting non-frozen parameters #47

Closed: ptigwe closed this 2 years ago

ptigwe commented 2 years ago

This adds a `trainable_parameters` filter that excludes the frozen parts of a given `tx.Module`. It also updates the docs to describe the change and to show how the filter can be used when training a model that has frozen sections.
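The PR description does not show the treex code itself, so the following is only a dependency-free sketch of the idea in plain Python, not treex's implementation. `ToyModule`, `trainable_parameters`, and `apply_update` are hypothetical stand-ins. The sketch makes two assumptions of its own: freezing a module also freezes all of its descendants, and excluded leaves are replaced with `None` so the filtered tree keeps the module's shape. A training step can then update only the leaves that survived the filter.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyModule:
    """Hypothetical stand-in for a tx.Module: scalar parameters plus named child modules."""

    params: Dict[str, float] = field(default_factory=dict)
    children: Dict[str, "ToyModule"] = field(default_factory=dict)
    frozen: bool = False


def trainable_parameters(module: ToyModule, parent_frozen: bool = False) -> Dict[str, Any]:
    """Return a tree mirroring `module` in which every frozen parameter is replaced by None.

    Assumption of this sketch: freezing a module also freezes all of its descendants.
    """
    frozen = parent_frozen or module.frozen
    return {
        "params": {name: (None if frozen else value) for name, value in module.params.items()},
        "children": {
            name: trainable_parameters(child, frozen) for name, child in module.children.items()
        },
    }


def apply_update(module: ToyModule, grads: Dict[str, Any], lr: float) -> None:
    """Plain SGD step that skips leaves marked None, so frozen parameters never change."""
    for name, grad in grads["params"].items():
        if grad is not None:
            module.params[name] -= lr * grad
    for name, child in module.children.items():
        apply_update(child, grads["children"][name], lr)


# Usage: a frozen "encoder" (e.g. pretrained weights) followed by a trainable "head".
model = ToyModule(
    children={
        "encoder": ToyModule(params={"w": 1.0, "b": 0.0}, frozen=True),
        "head": ToyModule(params={"w": 2.0, "b": 0.5}),
    }
)
trainable = trainable_parameters(model)
# For the toy loss 0.5 * sum(p**2), each parameter's gradient equals the parameter itself,
# so the filtered parameter tree doubles as the filtered gradient tree.
apply_update(model, trainable, lr=0.1)
```

After the update, the encoder's parameters are unchanged while the head's `w` and `b` have moved toward zero. Keeping the filtered tree the same shape as the module (rather than dropping frozen entries) lets it be paired leaf by leaf with the original parameters during the update.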