juanitorduz opened this issue 11 months ago
Hi @juanitorduz, if you need this feature, please feel free to put it in `contrib.module`. I guess you can mimic `random_flax_module` for an implementation. If you need to clarify something, please leave a comment in this issue thread.
Great! Makes sense. Thank you, @fehiepsi! I'll give it a try in the upcoming months!
I've been using the following in my package flowjax to register parameters for Equinox modules:
```python
from collections.abc import Callable

import equinox as eqx
import numpyro
from jaxtyping import PyTree


def register_params(
    name: str,
    model: PyTree,
    filter_spec: Callable | PyTree = eqx.is_inexact_array,
):
    """Register numpyro params for an arbitrary pytree.

    This partitions the parameters and static components, registers the parameters using
    numpyro.param, then recombines them. This should be called from within an inference
    context to have an effect, e.g. within a numpyro model or guide function.

    Args:
        name: Name for the parameter set.
        model: The pytree (e.g. an equinox module, flowjax distribution/bijection).
        filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a
            callable `leaf -> bool`, or a PyTree with prefix structure matching `model`
            with True/False values. Defaults to `eqx.is_inexact_array`.
    """
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    # Put the static (non-trainable) components back in place.
    return eqx.combine(params, static)
```
It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training, I just use `eqx.combine(trained_params, model)` to retrieve the trained module.
Thank you, @danielward27! This will be a great entry point! (I am planning to tackle this sometime in February.)
It would be nice to have `equinox_module` and `random_equinox_module` model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, as Equinox seems to be under quite active development. Would this be a good addition?
I could give it a shot in the upcoming months, but I will need some guidance :) Still, I would also be happy if a more experienced dev wants to give it a go. XD