pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Equinox models integration #1709

Open juanitorduz opened 9 months ago

juanitorduz commented 9 months ago

It would be nice to have equinox_module and random_equinox_module model functions in https://github.com/pyro-ppl/numpyro/blob/master/numpyro/contrib/module.py, as Equinox seems to be under quite active development.

Would this be a good addition?

I could give it a shot in the upcoming months, but I would need some guidance :) That said, I am also happy if a more experienced dev wants to give it a go. XD

fehiepsi commented 9 months ago

Hi @juanitorduz, if you need this feature, please feel free to put it in contrib.module. I guess you can mimic random_flax_module for an implementation. If you need to clarify something, please leave a comment in this issue thread.

juanitorduz commented 9 months ago

Great, makes sense. Thank you @fehiepsi! I'll give it a try in the upcoming months!

danielward27 commented 8 months ago

I've been using the following function in my package flowjax to register parameters for Equinox modules.


from collections.abc import Callable

import equinox as eqx
import numpyro
from jaxtyping import PyTree


def register_params(
    name: str,
    model: PyTree,
    filter_spec: Callable | PyTree = eqx.is_inexact_array,
):
    """Register numpyro params for an arbitrary pytree.

    This partitions the parameters and static components, registers the parameters using
    numpyro.param, then recombines them. This should be called from within an inference
    context to have an effect, e.g. within a numpyro model or guide function.

    Args:
        name: Name for the parameter set.
        model: The pytree (e.g. an equinox module, flowjax distribution/bijection).
        filter_spec: Equinox `filter_spec` for specifying trainable parameters. Either a
            callable `leaf -> bool`, or a PyTree with prefix structure matching `model`
            with True/False values. Defaults to `eqx.is_inexact_array`.

    """
    # Split into trainable leaves (params) and everything else (static).
    params, static = eqx.partition(model, filter_spec)
    if callable(params):
        # Wrap to avoid special handling of callables by numpyro. Numpyro expects a
        # callable to be used for lazy initialization, whereas in our case it is likely
        # a callable module we wish to train.
        params = numpyro.param(name, lambda _: params)
    else:
        params = numpyro.param(name, params)
    # Merge the (possibly substituted) params back with the static parts.
    return eqx.combine(params, static)

It's not particularly well tested, and I'm not familiar with the implementations for other frameworks, but maybe it's another useful reference. After training, I just use eqx.combine(trained_params, model) to retrieve the trained module.

juanitorduz commented 8 months ago

Thank you @danielward27! This will be a great entry point! (I am planning to tackle this sometime in February.)