pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Port implementation of SimplexToOrderedTransform from numpyro #3320

peblair closed this 9 months ago

peblair commented 9 months ago

This pull request adds SimplexToOrderedTransform to Pyro. It is a port of the numpyro class of the same name.
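For readers unfamiliar with the transform, its math can be sketched in plain Python. This is a minimal sketch that assumes the port keeps numpyro's definition: the partial sums of the first K-1 simplex coordinates go through logit and are shifted by an anchor point, which yields K-1 strictly increasing cutpoints. The function names below are hypothetical and are not Pyro's API; the real class works on PyTorch tensors with batch dimensions.

```python
import math


def _softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def _sigmoid(z):
    """Numerically stable logistic function 1 / (1 + exp(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def simplex_to_ordered(x, anchor_point=0.0):
    """Hypothetical forward map: a K-simplex point -> K-1 ordered cutpoints.

    Each cutpoint is logit(partial sum of x) + anchor_point.
    """
    y, s = [], 0.0
    for v in x[:-1]:
        s += v
        y.append(math.log(s) - math.log1p(-s) + anchor_point)
    return y


def ordered_to_simplex(y, anchor_point=0.0):
    """Hypothetical inverse: sigmoid of the unshifted cutpoints, padded with 0 and 1, then differenced."""
    s = [0.0] + [_sigmoid(v - anchor_point) for v in y] + [1.0]
    return [b - a for a, b in zip(s, s[1:])]


def log_abs_det_jacobian(y, anchor_point=0.0):
    """log|det dy/dx| of the forward map.

    The partial-sum step is triangular with unit diagonal, so only the logit terms
    contribute: d logit(s)/ds = 1 / (s (1 - s)), and with z = logit(s),
    -log(s) - log(1 - s) = softplus(-z) + softplus(z).
    """
    return sum(_softplus(v - anchor_point) + _softplus(-(v - anchor_point)) for v in y)


x = [0.3, 0.1, 0.4, 0.2]
y = simplex_to_ordered(x)
print([round(v, 4) for v in y])  # [-0.8473, -0.4055, 1.3863]
```

The anchor point only shifts every cutpoint by the same constant, so it changes neither the ordering nor the Jacobian term.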

I was implementing a model in Pyro based on this numpyro guide, and this transform was the only piece I had to reimplement.

Because this transform is defined only on the simplex (vectors of positive values that sum to one), I have also added the ability to specify which distribution family generates the inputs in the transformation tests; the tests for this new transform draw their inputs from a Dirichlet distribution.
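Why the tests need a configurable input distribution can be sketched with a toy harness. This assumes a generic test that draws inputs, checks they lie in the transform's domain, and checks the output is ordered; standard-normal draws fail the domain check, while normalized Gamma draws (a standard way to sample a Dirichlet) pass it. All helper names here, including check_transform, are hypothetical and do not mirror Pyro's actual test suite.

```python
import math
import random


def sample_dirichlet(alpha, rng):
    """One Dirichlet(alpha) draw: normalize independent Gamma(alpha_i, 1) draws."""
    gammas = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(gammas)
    return [g / total for g in gammas]


def sample_normal(k, rng):
    """k independent standard-normal draws, a typical default test input."""
    return [rng.gauss(0.0, 1.0) for _ in range(k)]


def on_simplex(x, tol=1e-9):
    """True when every coordinate is positive and the coordinates sum to one."""
    return all(v > 0 for v in x) and abs(sum(x) - 1.0) < tol


def simplex_to_ordered(x):
    """Toy forward map (no anchor): logit of the partial sums of the first K-1 coordinates."""
    y, s = [], 0.0
    for v in x[:-1]:
        s += v
        y.append(math.log(s) - math.log1p(-s))
    return y


def check_transform(sample_fn, n=100, seed=0):
    """Hypothetical harness: draw n inputs; each must be in the domain and map to an ordered vector."""
    rng = random.Random(seed)
    for _ in range(n):
        x = sample_fn(rng)
        if not on_simplex(x):
            return False  # input falls outside the transform's domain
        y = simplex_to_ordered(x)
        if any(a >= b for a, b in zip(y, y[1:])):
            return False  # output is not strictly increasing
    return True


dirichlet_ok = check_transform(lambda rng: sample_dirichlet([1.0] * 4, rng))
normal_ok = check_transform(lambda rng: sample_normal(4, rng))
```

With Dirichlet inputs the check is expected to pass; with standard-normal inputs it fails on the first draw, since those values are neither all positive nor normalized.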

peblair commented 9 months ago

@fritzo Thanks for giving this a look! The issues you flagged should be fixed now. Let me know if there's anything else this needs.