pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Masking a Mixture Model makes `component_` methods unavailable #1885

Open nstarman opened 1 month ago

nstarman commented 1 month ago

The `MaskedDistribution` object wraps a base distribution but does not necessarily expose that distribution's public API. I'm raising this issue to propose adding support for the mixture's `component_` methods.

Possible Options:

  1. Some mechanism that automatically wraps methods accessed via `__getattr__`.
  2. Add a `MaskedMixtureDistribution` class that is returned by `MixtureDistribution.mask()`.
  3. Instead of wrapping the mixture in a `MaskedDistribution`, `.mask()` returns a mixture of the same type whose component distributions have each been masked with `.mask()`.

fehiepsi commented 1 month ago

If you want to access properties of the base distribution, the third option makes sense to me. But I think it is better to access the components via `d.base_dist.[attribute]`.

nstarman commented 1 month ago

I can push a PR for option 3, but I worry this breaks the API promise that `.mask()` returns a `MaskedDistribution`; it would start failing `isinstance` checks. IMO the best option is to do 2, where `MaskedMixtureDistribution` is the intersection of the `MixtureDistribution` and `MaskedDistribution` APIs, i.e. an instance of both. I think 3 can be done in addition, which nicely ensures that the component distributions are masked even when accessed individually.

> But I think it is better to access the components via d.base_dist.[attribute]

That's what I do currently, but it loses the mask, so none of the `component_` methods are masked.

fehiepsi commented 1 month ago

You can mask those components the same way you mask the mixture. I would recommend keeping the components as-is and masking them only when needed.

(Sent by email in reply to nstarman's comment of Sat, Oct 12, 2024, 1:14 PM.)

nstarman commented 1 month ago

Are you saying option 2 is best then?

fehiepsi commented 1 month ago

All of them modify the "default" behavior and introduce new behavior that is "potentially more convenient in some use cases". I just meant that you can do `d.base_dist.component...mask(...)` to get a masked component without modifying the default behavior. For the requirement "components need to be masked", one option could be to introduce a `mask_components` method for the Mixture distribution; that gives you the flexibility to introduce new behavior.