TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

`OrderedBijector` #186

Closed: torfjelde closed this 3 years ago

torfjelde commented 3 years ago

Something similar to Stan's `ordered` type has been requested on several occasions, e.g. https://github.com/TuringLang/Turing.jl/discussions/1535.

This PR introduces `OrderedBijector` and `Bijectors.ordered(d::Distribution)` to address this.
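The transform documented for Stan's `ordered` type maps an unconstrained vector x to a strictly increasing vector y via y[1] = x[1] and y[k] = y[k-1] + exp(x[k]). The following is a minimal standalone Julia sketch of that map, its inverse, and its log-absolute Jacobian determinant. The function names `ordered_transform`, `ordered_inverse` and `ordered_logabsdetjac` are hypothetical; this is not the code added in this PR and does not use the Bijectors.jl API.

```julia
# Hypothetical standalone sketch of the Stan-style `ordered` transform
# (not the Bijectors.jl implementation). Assumes 1-based indexing.

# Forward map: y[1] = x[1], y[k] = y[k-1] + exp(x[k]) for k >= 2,
# so y is strictly increasing for any real input x.
function ordered_transform(x::AbstractVector{<:Real})
    y = similar(x, float(eltype(x)))
    isempty(x) && return y
    y[1] = x[1]
    for k in 2:length(x)
        y[k] = y[k-1] + exp(x[k])
    end
    return y
end

# Inverse map: x[1] = y[1], x[k] = log(y[k] - y[k-1]).
# Only valid for strictly increasing y.
function ordered_inverse(y::AbstractVector{<:Real})
    x = similar(y, float(eltype(y)))
    isempty(y) && return x
    x[1] = y[1]
    for k in 2:length(y)
        x[k] = log(y[k] - y[k-1])
    end
    return x
end

# log|det J| of the forward map. J is lower triangular with diagonal
# (1, exp(x[2]), ..., exp(x[end])), so the log-determinant is sum(x[2:end]).
function ordered_logabsdetjac(x::AbstractVector{<:Real})
    s = zero(float(eltype(x)))
    for k in 2:length(x)
        s += x[k]
    end
    return s
end

x = [0.5, -1.0, 0.0, 2.0]
y = ordered_transform(x)     # strictly increasing
ordered_inverse(y) ≈ x       # round trip
```

Because y[k] depends only on x[1:k], the Jacobian is lower triangular, which is why the log-determinant reduces to a plain sum over the unconstrained coordinates.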

I also took the liberty of moving the Tracker.jl compatibility code into a submodule, as is already done for the other AD packages, and made the overloads more explicit by qualifying the function names instead of importing them. With everything imported, it was quite difficult to figure out what was missing after I moved the code into a submodule.
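A minimal, hypothetical illustration of the "qualify instead of import" pattern described above. The names `Parent`, `Compat`, `transform` and `Wrapped` are made up and only stand in for Bijectors.jl and an AD compatibility submodule such as the Tracker one; this is not the actual code from the PR.

```julia
# Hypothetical names throughout; stands in for a package and its AD-compat submodule.
module Parent

# Generic function owned by the parent module.
transform(x::Real) = 2x

module Compat
    # Bring only the parent module's *name* into scope; none of its
    # functions are imported, so bare `transform` is not defined here.
    import ..Parent

    # Stand-in for an AD-specific wrapper type (e.g. a tracked number).
    struct Wrapped
        value::Float64
    end

    # Fully qualified method definition: it is explicit that this adds a
    # method to `Parent.transform` rather than defining a new local function.
    Parent.transform(w::Wrapped) = Wrapped(2 * w.value)
end

end # module Parent

Parent.transform(3.0)                          # 6.0
Parent.transform(Parent.Compat.Wrapped(1.5))   # Wrapped(3.0)
```

With `import ..Parent: transform`, the last definition could be written as a bare `transform(w::Wrapped) = ...`, which makes it harder to see at a glance which definitions extend the parent's functions.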

torfjelde commented 3 years ago

> I know that Stan uses this transform, but when I read the Stan docs a while ago I already wondered how it compares to other possible alternatives (e.g. using `log1pexp`, or completely different approaches).

No, I haven't considered alternative approaches. I'm all for exploring them, though; I just think it's more important to have support for this first, and then we can experiment with different approaches (unless we already know that one is better than another).
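To make the `log1pexp` alternative mentioned above concrete, one hypothetical variant replaces the `exp` increment with softplus, y[k] = y[k-1] + log1pexp(x[k]). This is a standalone Julia sketch: `softplus` and `invsoftplus` are defined locally (LogExpFunctions.jl provides `log1pexp` for real use, but that dependency is left out here), and none of these names come from the PR.

```julia
# Hypothetical alternative, not part of this PR: softplus instead of exp.

# Numerically stable softplus, log(1 + exp(t)).
softplus(t::Real) = t > 0 ? t + log1p(exp(-t)) : log1p(exp(t))

# Inverse of softplus for d > 0: log(exp(d) - 1) = d + log(1 - exp(-d)).
invsoftplus(d::Real) = d + log(-expm1(-d))

# Forward map: y[1] = x[1], y[k] = y[k-1] + softplus(x[k]). Assumes 1-based indexing.
function ordered_softplus_transform(x::AbstractVector{<:Real})
    y = similar(x, float(eltype(x)))
    isempty(x) && return y
    y[1] = x[1]
    for k in 2:length(x)
        y[k] = y[k-1] + softplus(x[k])
    end
    return y
end

# Inverse map; only valid for strictly increasing y.
function ordered_softplus_inverse(y::AbstractVector{<:Real})
    x = similar(y, float(eltype(y)))
    isempty(y) && return x
    x[1] = y[1]
    for k in 2:length(y)
        x[k] = invsoftplus(y[k] - y[k-1])
    end
    return x
end

# The Jacobian is lower triangular with diagonal (1, σ(x[2]), ..., σ(x[end])),
# where σ is the logistic function and log σ(t) = -softplus(-t).
function ordered_softplus_logabsdetjac(x::AbstractVector{<:Real})
    s = zero(float(eltype(x)))
    for k in 2:length(x)
        s -= softplus(-x[k])
    end
    return s
end
```

Only the increment and its log-derivative change relative to the `exp` version; whether this behaves better in practice is exactly the open question from the discussion.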

devmotion commented 3 years ago

Sure, at this stage I was mostly curious whether another approach is already known to be better.

torfjelde commented 3 years ago

Happy with this now, @devmotion? The only thing left is the version bump.

torfjelde commented 3 years ago

There's a bug in DistributionsAD.jl with ReverseDiff.jl again :confused:

Can we move the DistSpec tests to DistributionsAD.jl?

Now that `link`, `invlink`, and `logpdf_with_trans` are mere wrappers around the bijectors, AD-compatibility of `logpdf` for a distribution `d`, together with AD-compatibility of the bijector's transform and its `logabsdetjac` in Bijectors.jl, is sufficient to make `logpdf_with_trans` AD-compatible. That is, Bijectors.jl should be responsible for AD-compatibility of the bijectors and nothing more.
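The argument rests on the change-of-variables formula. Writing $b$ for the bijector that maps the support of `d` to $\mathbb{R}^K$ and $y = b(x)$ (notation assumed here, not taken from the thread), the log-density of the transformed variable splits into a `logpdf` term and a log-Jacobian term:

```math
\begin{aligned}
\log p_Y(y) &= \log p_X\big(b^{-1}(y)\big) + \log\big|\det J_{b^{-1}}(y)\big| \\
            &= \log p_X(x) - \log\big|\det J_{b}(x)\big|, \qquad x = b^{-1}(y).
\end{aligned}
```

The first term corresponds to `logpdf(d, x)` and the second to the bijector's `logabsdetjac` (its sign depends on the direction in which it is evaluated), so if both are differentiable under a given AD backend, their sum is too.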

PRs to Bijectors.jl have also been delayed by DistributionsAD.jl a few too many times, so moving these tests over to DistributionsAD.jl would help in that regard too.

EDIT: I'm working on this right now. It seems the only `logpdf_with_trans` method for which the above argument doesn't hold is

https://github.com/TuringLang/Bijectors.jl/blob/edbb5603004e9567f59f6a659b18923fc66b18ec/src/Bijectors.jl#L209-L217

but even so it seems more than worth it.

torfjelde commented 3 years ago

Good to go?

torfjelde commented 3 years ago

Should be good now :+1: