TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License
200 stars 33 forks

Use ChangesOfVariables and InverseFunctions #212

Closed oschulz closed 2 years ago

oschulz commented 2 years ago

This PR implements the changes discussed in #199, by adding support for JuliaMath/ChangesOfVariables and JuliaMath/InverseFunctions. Both are very lightweight, low-dependency, low-bias packages designed to enable composability of packages that provide/implement or use variable transformation capabilities.

Specifically, this PR adds support for ChangesOfVariables.with_logabsdet_jacobian(::AbstractBijector, ::Any) (a direct equivalent of - and indeed modeled after - the current Bijectors.forward(::AbstractBijector, ::Any)) and InverseFunctions.inverse (a direct equivalent of the current Base.inv(::AbstractBijector)).

The following registered packages directly depend on Bijectors, currently: DifferentialEvolutionMCMC DynamicPPL Turing TuringModels Transits ParameterHandling MeasureTheory AdvancedVI AIBECS Soss

None of those dependent packages define subtypes of AbstractBijector, specialize Bijectors.forward or seem to specialize Base.inv (hope I didn't overlook any). So it seems reasonable to deprecate Bijectors.forward(::AbstractBijector, ::Any) and Base.inv(::AbstractBijector) directly and replace all use of them inside of Bijectors.jl with ChangesOfVariables.with_logabsdet_jacobian and InverseFunctions.inverse.

The return type of with_logabsdet_jacobian is slightly different from that of forward, though: it returns a Tuple instead of a NamedTuple{(:rv, :logabsdetjac)}. It seems that the only package that uses these fields is MeasureTheory, in a single place (and apparently not at all anymore on the current master branch). This is handled in the deprecation of forward in this PR.
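
As a rough illustration of that difference in return types, here is a minimal sketch of how a deprecated forward could wrap the tuple-returning function. All names (ToyExp, toy_with_logabsdet_jacobian, toy_forward) are hypothetical stand-ins, not the actual Bijectors.jl or ChangesOfVariables code:

```julia
# Hypothetical toy bijector: y = exp(x) for a real scalar x.
struct ToyExp end
(::ToyExp)(x::Real) = exp(x)

# New-style function: returns a plain Tuple (y, logabsdetjac).
# For y = exp(x), log|dy/dx| = x.
toy_with_logabsdet_jacobian(b::ToyExp, x::Real) = (b(x), x)

# Old-style function kept as a deprecated shim: it repackages the Tuple
# into the NamedTuple layout that existing callers expect.
function toy_forward(b, x)
    Base.depwarn(
        "`toy_forward(b, x)` is deprecated, use `toy_with_logabsdet_jacobian(b, x)`.",
        :toy_forward,
    )
    y, ladj = toy_with_logabsdet_jacobian(b, x)
    return (rv = y, logabsdetjac = ladj)
end

res = toy_forward(ToyExp(), 0.0)   # fields: res.rv, res.logabsdetjac
```

Callers that still access .rv and .logabsdetjac keep working through the shim, while new code can destructure the Tuple directly.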

Closes #199.

CC @torfjelde, @devmotion, @willtebbutt, @cscherrer

yebai commented 2 years ago

Cc @phipsgabler

devmotion commented 2 years ago

@oschulz it seems you forgot to push some changes?

oschulz commented 2 years ago

@oschulz it seems you forgot to push some changes?

Was still working on them. :-) Changes are pushed now.

oschulz commented 2 years ago

Don't run workflow yet, still fixing tests locally.

oschulz commented 2 years ago

Ready for review and CI (could you trigger the workflow, @devmotion ?)

I have one local test failure with Julia v1.7 in "test/transform.jl:151" (section with a comment "This should fail at the minute") but I get the same test failure with the current master branch, so it seems unrelated.

oschulz commented 2 years ago

Sorry I missed too many forwards initially, @devmotion !

I fixed a few things, should be ready for another CI run now.

oschulz commented 2 years ago

Should we also deprecate logabsdetjac? Currently, the default implementation of with_logabsdet_jacobian falls back on logabsdetjac, while ChangesOfVariables has with_logabsdet_jacobian as the primary function. JuliaMath/ChangesOfVariables.jl#3 (still undecided) would add a logabsdet_jacobian as an analog of logabsdetjac - but even if added, things would work the other way round, logabsdet_jacobian would fall back on with_logabsdet_jacobian.

The only package that currently seems to add methods to logabsdetjac is Transits.jl, in a single place, to define the LADJ of Kipping13Transform. That could easily be changed to with_logabsdet_jacobian, which would also add support for ChangesOfVariables.jl to Transits.jl. @mileslucas, would that be Ok from your side? Soss.jl also defines a method of with_logabsdet_jacobian in a single place in the current release, but that seems gone on the master branch (@cscherrer?).

I'm not sure we can cleanly support both ways in Bijectors.jl (users defining either with_logabsdet_jacobian or logabsdetjac, and the other function then using a default method if not specialized as well), at least not without ugly trickery. Removing the "primary" status from logabsdetjac would be breaking - but until we do, users can't code a Bijector the "ChangesOfVariables way", I think.

devmotion commented 2 years ago

Maybe leave this for a separate PR as it seems to be a more fundamental change of Bijectors and, I assume, has to be benchmarked carefully? Maybe it would also benefit from an upstream definition of logabsdet_jacobian which seems to be another reason not to rush it.

In principle, however, I think one could use something like

logabsdetjac(b::AbstractBijector, x) = last(with_logabsdet_jacobian(b, x))
with_logabsdet_jacobian(b::AbstractBijector, x) = (b(x), logabsdetjac(b, x))

(potentially with some deprecations) to support implementations in both Bijectors and ChangesOfVariables style. But I guess it would be cleaner to do in a separate PR.

oschulz commented 2 years ago

In principle, however, I think one could use something like

logabsdetjac(b::AbstractBijector, x) = last(with_logabsdet_jacobian(b, x))
with_logabsdet_jacobian(b::AbstractBijector, x) = (b(x), logabsdetjac(b, x))

Won't that result in a stack overflow if neither is defined?

Maybe leave this for a separate PR

Sounds good - I think @torfjelde was planning a deeper overhaul anyway?

devmotion commented 2 years ago

Won't that result in a stack overflow if neither is defined?

Yes, it does. But I assumed this would be fine - for a user or developer it's an indication that you should define (at least) one of these methods.
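
To make the trade-off concrete, here is a minimal sketch of that mutual-fallback pattern with hypothetical toy types (ToyBijector, toy_ladj, toy_with_ladj). It only mimics the two definitions quoted above and is not the Bijectors.jl implementation:

```julia
abstract type ToyBijector end

# Each generic method falls back on the other one.
toy_ladj(b::ToyBijector, x) = last(toy_with_ladj(b, x))
toy_with_ladj(b::ToyBijector, x) = (b(x), toy_ladj(b, x))

# "Bijectors style": only the LADJ function is specialized.
struct ToyScale <: ToyBijector; a::Float64; end
(b::ToyScale)(x::Real) = b.a * x
toy_ladj(b::ToyScale, x::Real) = log(abs(b.a))

# "ChangesOfVariables style": only the combined function is specialized.
struct ToyShift <: ToyBijector; c::Float64; end
(b::ToyShift)(x::Real) = x + b.c
toy_with_ladj(b::ToyShift, x::Real) = (x + b.c, 0.0)

# Neither function is specialized for this type.
struct ToyBroken <: ToyBijector end
(b::ToyBroken)(x) = x

toy_with_ladj(ToyScale(2.0), 3.0)   # value from the fallback, LADJ from the specialization
toy_ladj(ToyShift(1.0), 3.0)        # LADJ extracted from the specialized combined function
# toy_ladj(ToyBroken(), 1.0) would recurse between the two fallbacks
# until a StackOverflowError is thrown.
```

The design choice is that a type needs to specialize only one of the two functions; the price is an unhelpful error when it specializes neither.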

cscherrer commented 2 years ago

Soss.jl also defines a method of with_logabsdet_jacobian in a single place in the current release, but that seems gone on the master branch (@cscherrer?).

IIRC this had been an optional dependency, but it wasn't working out because of exported Distributions, and it didn't allow manifolds to be represented as embeddings from lower-dimensional spaces. I think I'll need to wait for https://github.com/TuringLang/Bijectors.jl/pull/183 before I can use it.

oschulz commented 2 years ago

I get an intermittent (but frequent) test failure in "test/transform.jl:151": For

dist = Dirichlet([1000 * one(Float64), eps(Float64)])
x = [0.9999999999999998, 0.0]  # a value returned by rand(dist)

The test

@test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing

fails, due to

julia> using Bijectors, LinearAlgebra

julia> b = Bijectors.SimplexBijector{false}();
julia> x = [1.0 - eps(typeof(1.0)), 0.0];
julia> x2 = [1.0, 0.0];

julia> b(x), b(x2)
([35.35050620855721, 2.220446049250313e-16], [36.04365338911715, 0.0])

julia> logabsdetjac(b, x), logabsdet(Bijectors.jacobian(b, x))[1]
(36.04365338911715, 35.35050620855721)

julia> logabsdetjac(b, x2), logabsdet(Bijectors.jacobian(b, x2))[1]
(36.04365338911715, 36.04365338911715)

I get the same on the current master branch, so I don't think it's related to this PR.

I put in a workaround in the test for now.

oschulz commented 2 years ago

@devmotion one thing I hadn't realized until now is that Bijectors allows for forward(b, multiple_xs_as_matrix), returning a matrix of result values and a vector of LADJ-values. Our current definition of ChangesOfVariables.with_logabsdet_jacobian doesn't really allow for that, I think. What's your take on this?

devmotion commented 2 years ago

Our current definition of ChangesOfVariables.with_logabsdet_jacobian doesn't really allow for that, I think. What's your take on this?

That this will be resolved by the refactoring and use of Batch or something similar (hopefully): https://github.com/TuringLang/Bijectors.jl/discussions/178

oschulz commented 2 years ago

That this will be resolved by the refactoring and use of Batch or something similar (hopefully): #178

Ok, then - I guess this break from contract isn't actually harmful for now. LADJs are only ever added/subtracted, which works with a vector just as well as with a Real. ;-)
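
A small sketch of why a vector of LADJ values is harmless in practice, assuming a hypothetical elementwise-exp transform applied column-wise to a matrix of samples (toy_forward_batch is illustrative, not a Bijectors.jl function, and the log-densities are made up):

```julia
# Each column of X is one sample. For elementwise y = exp.(x), the
# log-abs-det-Jacobian of one sample is sum(x).
toy_forward_batch(X::AbstractMatrix) = (exp.(X), vec(sum(X; dims=1)))

X = [0.0 1.0;
     0.0 2.0]
Y, ladj = toy_forward_batch(X)   # Y is a matrix, ladj has one entry per column

# Correcting per-sample log-densities works as it would with a scalar LADJ,
# just with broadcasting.
logp_x = [-1.0, -2.0]            # made-up log-densities of the two samples
logp_y = logp_x .- ladj
```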

I don't understand where the

Classic: Error During Test at /home/runner/work/Bijectors.jl/Bijectors.jl/test/ad/utils.jl:35
  Test threw exception
  Expression: ≈(ReverseDiff.gradient(f, x), finitediff, rtol=rtol, atol=atol)
  UndefVarError: inverse not defined

test error(s) are coming from, though. How can inverse not be defined inside ReverseDiff.gradient(f, x) if f(x) works perfectly fine?

devmotion commented 2 years ago

How can inverse not be defined inside ReverseDiff.gradient(f, x) if f(x) works perfectly fine?

The optional ReverseDiff and Tracker support are in different submodules in src/compat/reversediff.jl and src/compat/tracker.jl, to keep ReverseDiff.TrackedArray/Real and Tracker.TrackedArray/Real separate. I guess you have to load Bijectors.inverse (or InverseFunctions.inverse) in the submodule(s).
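
A minimal sketch of that scoping rule, with hypothetical module and function names: a nested module does not see its parent's bindings unless it imports them explicitly.

```julia
module ToyParent
    # Stand-in for a function that the parent module owns or re-exports.
    toy_inverse(f) = f === exp ? log : error("no inverse known for $f")

    module ToyCompat
        # Without this import, calling `toy_inverse` below would throw
        # an UndefVarError, even though it works fine in ToyParent.
        using ..ToyParent: toy_inverse

        roundtrip(x) = toy_inverse(exp)(exp(x))
    end
end

ToyParent.ToyCompat.roundtrip(2.0)
```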

oschulz commented 2 years ago

Will try.

The other persistent test failure can be traced back to this MWE

using Bijectors, Zygote
m = Bijectors.PartitionMask(3, [1], [2])
g(x) = Bijectors.Scale(1.0)
cl = Bijectors.Coupling(g, m)
x = [1.0, 2.0, 3.0]
dy = [4.0, 5.0, 6.0]
Zygote.pullback(cl, x)[2](dy)

resulting in

ERROR: MethodError: no method matching zero(::Type{Nothing})

This seems unrelated to this PR, though, happens on the current master branch of Bijectors as well.

devmotion commented 2 years ago

There were quite a few changes in Zygote recently, maybe they uncovered a bug in Bijectors or introduced one in Zygote.

oschulz commented 2 years ago

Yes, that was my suspicion as well.

oschulz commented 2 years ago

I guess you have to load Bijectors.inverse (or InverseFunctions.inverse) in the submodule(s).

I tried (just pushed it here), but I still get inverse not defined.

oschulz commented 2 years ago

Here's a more minimal MWE for the Zygote problem:

using Bijectors, Zygote
using Bijectors: PartitionMask, combine

m = PartitionMask(3, [1], [2])
a, b, c = ([1.0], [2.0], [3.0])

y = combine(m, a, b, c)
Zygote.pullback(combine, m, a, b, c)[2](y)

results in

ERROR: LoadError: MethodError: no method matching zero(::Type{Nothing})

oschulz commented 2 years ago

Ok, added an rrule for combine(m::PartitionMask, x_1, x_2, x_3), that took care of the Zygote errors.

oschulz commented 2 years ago

Ok, pkg> test Bijectors passed locally for me now. Could someone trigger the CI workflow?

oschulz commented 2 years ago

Yay, tests green. :-)

oschulz commented 2 years ago

fix the deprecation warnings that show up in the tests

On it.

torfjelde commented 2 years ago

Sorry for being a bit awol; past week has been busy, preparing to go home for Christmas. But just had a quick look and this looks great! Thank you @oschulz !

It seems like @devmotion has already done a proper review of this, so tbh I don't have any comments (beyond his latest on deprecation tests and bumping the version number) :+1:

So feel free to approve it when you're happy @devmotion . I'll be travelling tomorrow, so won't be able to have a look at this again until Tues at the earliest.

oschulz commented 2 years ago

But just had a quick look and this looks great! Thank you @oschulz !

Thanks! It did get a lot bigger than I had expected, initially. :-)

In the meantime, can you update the version number and fix the deprecation warnings

Version number is up and I think I finally eliminated the last remaining deprecation warnings. Let's see if the tests go through clean this time.

oschulz commented 2 years ago

Ok, looks clean.

oschulz commented 2 years ago

Ok, inverse and with_logabsdet_jacobian are re-exported.

oschulz commented 2 years ago

Thanks for all the comments and suggestions!

torfjelde commented 2 years ago

Indeed, thank you so much @oschulz ! Great stuff:)

oschulz commented 2 years ago

@devmotion if it's fine with you in principle, I'd draft a PR for Distributions to "lift" TransformedDistribution from Bijectors into Distributions. I think it will be a very valuable tool to have, and with InverseFunctions and ChangesOfVariables (and support for them in Bijectors) in place we can make it available without a dependency on Bijectors (I think, haven't drafted the code yet).

devmotion commented 2 years ago

Sounds reasonable - currently we handle only the special case of affine transformations (limited but hopefully soon in a bit more general way) so I think it would be a valuable addition.

oschulz commented 2 years ago

Sounds reasonable - currently we handle only the special case of affine transformations

I do have an idea how we can (hopefully) support arbitrary variate types (I definitely want ValueShapes.NamedTupleDist to work with this). I will probably need to include a definition of Random.gentype for distributions (I think that was already under discussion?) and I may need to at least define a new VariateForm for struct types. @devmotion if that's fine with you at least in principle, I would make a concrete draft PR as a basis for discussing details.

devmotion commented 2 years ago

It would be nice to keep changes as minimal as possible (but e.g. with definitions general enough to allow such use cases without breaking changes later on), so that the PR does not become too large and discussions are less likely to diverge or lose focus on the main changes.

devmotion commented 2 years ago

In particular the Random stuff is a very sensitive area where people tend to have strong opinions :smile:

oschulz commented 2 years ago

Understood - I'll try to make something compact.

devmotion commented 2 years ago

BTW the current design proposal for eltype (but on purpose not including gentype etc.) is https://github.com/JuliaStats/Distributions.jl/pull/1433. But maybe it is sufficient to just call regular rand of the untransformed distribution for sampling?

oschulz commented 2 years ago

But maybe it is sufficient to just call regular rand of the untransformed distribution for sampling

Yes. The tricky bit will be inferring the VariateForm of the transformed distribution. Base._return_type(f, (Random.gentype(orig_dist),)) will allow us to infer that in many cases, falling back on running the transformation on a single rand value and determining the VariateForm from the result (in the spirit of what Broadcast does to infer the return type, and what it does when it can't).
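
A minimal sketch of that inference-with-fallback idea. It uses Base.promote_op as a stand-in for the Base._return_type call mentioned above, and the function name toy_result_type and the explicit sample argument are hypothetical. It only determines the value type; mapping that type to a VariateForm is left out:

```julia
# Try to infer the result type of f on an input of type T. If inference does
# not yield a concrete type, fall back to applying f to one sample value.
function toy_result_type(f, ::Type{T}, sample::T) where {T}
    R = Base.promote_op(f, T)
    return isconcretetype(R) ? R : typeof(f(sample))
end

# Deliberately type-unstable function to exercise the fallback path.
unstable(x) = x > 0 ? 1 : "non-positive"

toy_result_type(exp, Float64, 1.0)        # Float64
toy_result_type(unstable, Float64, 1.0)   # Int, obtained from the sample
```

In the transformed-distribution setting the sample would presumably come from rand of the untransformed distribution, as suggested earlier in the thread.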