TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Support InverseFunctions.jl and ChangesOfVariables.jl #199

Closed: oschulz closed this issue 2 years ago

oschulz commented 3 years ago

To do

Context

Bijectors is really neat, but it's also kind of a heavy package with many dependencies (necessarily).

I have some variable transformation stuff in BAT.jl that I want to split out into a separate package, and I've recently been experimenting with some normalizing flow trafos - I'd like to make that kind of stuff compatible with Bijectors, but it would be too heavy as a dependency.

Would you guys (CC @devmotion) be interested in setting up a lightweight variable-trafo-interface package? Maybe just something like

function with_logabsdet_jacobian end

function with_logabsdet_jacobian(f::Base.ComposedFunction, x)
    # Apply the inner transform first, then the outer one, summing the LADJ terms.
    y_inner, ladj_inner = with_logabsdet_jacobian(f.inner, x)
    y, ladj_outer = with_logabsdet_jacobian(f.outer, y_inner)
    return (y, ladj_inner + ladj_outer)
end

function inverse end

# (outer ∘ inner)⁻¹ == inner⁻¹ ∘ outer⁻¹
inverse(f::Base.ComposedFunction) = Base.ComposedFunction(inverse(f.inner), inverse(f.outer))

inverse(f) = inv(f)  # Until JuliaLang/julia#42421 is decided one way or the other

with_logabsdet_jacobian would be the equivalent of Bijectors.forward, while f(x) would be used to just run the transform without calculating LADJ. So the behavior would be

mytrafo(x) == y

with_logabsdet_jacobian(mytrafo, x) == (y, ladj)

inverse(trafo2 ∘ trafo1)(x) == inverse(trafo1)(inverse(trafo2)(x))
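The behaviour listed above can be tried out with a small, dependency-free sketch. The definitions here only mirror the proposal and are not the API of any released package; the `Scale` type and the methods for `exp` and `log` are hypothetical and exist purely for illustration.

```julia
# Sketch only: mirrors the proposed interface, not a released package.
function with_logabsdet_jacobian end
function inverse end

# Composition: run the inner transform first and add up the LADJ terms.
function with_logabsdet_jacobian(f::Base.ComposedFunction, x)
    y_inner, ladj_inner = with_logabsdet_jacobian(f.inner, x)
    y, ladj_outer = with_logabsdet_jacobian(f.outer, y_inner)
    return (y, ladj_inner + ladj_outer)
end

# (outer ∘ inner)⁻¹ == inner⁻¹ ∘ outer⁻¹
inverse(f::Base.ComposedFunction) = inverse(f.inner) ∘ inverse(f.outer)

# Scalar exp/log: log|d exp(x)/dx| = x and log|d log(y)/dy| = -log(y).
with_logabsdet_jacobian(::typeof(exp), x::Real) = (exp(x), x)
with_logabsdet_jacobian(::typeof(log), y::Real) = (log(y), -log(y))
inverse(::typeof(exp)) = log
inverse(::typeof(log)) = exp

# Hypothetical scaling transform x -> a * x.
struct Scale{T<:Real}
    a::T
end
(s::Scale)(x::Real) = s.a * x
with_logabsdet_jacobian(s::Scale, x::Real) = (s.a * x, log(abs(s.a)))
inverse(s::Scale) = Scale(inv(s.a))

trafo = Scale(2.0) ∘ exp
x = 0.5
y, ladj = with_logabsdet_jacobian(trafo, x)

@assert y == trafo(x)              # same value as the plain call
@assert ladj ≈ x + log(2.0)        # LADJ of exp plus LADJ of the scaling
@assert inverse(trafo)(y) ≈ x      # round trip through the inverse
```

Calling `trafo(x)` directly skips the LADJ bookkeeping entirely, which is the point of keeping the two entry points separate.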

No autodiff or anything, that we'd leave to implementations like Bijectors.Bijector.

Long-term, it would be nice to get rid of inverse in favor of inv, if we can get inv(::ComposedFunction) into Base (just opened an issue in that regard: JuliaLang/julia#42421).

I'd volunteer to prototype a package (would be quite tiny, basically just the code above) if there's interest.

Edit: Replaced special argument type WithLADJ in proposal with function with_logabsdet_jacobian, as suggested by @devmotion

oschulz commented 2 years ago

ArraysOfArrays.jl could potentially provide the same, but I think there are a couple of caveats:

It could be more light-weight (and it uses Requires.jl).

I promised @cscherrer to get rid of that. :-)

but for a Batch you're really just interested in converting anything into a 1-dimensional collection.

Why a 1-dimensional one - wouldn't you lose information about which parts belong to which sample/event/...?

torfjelde commented 2 years ago

I promised @cscherrer to get rid of that. :-)

Nice:)

Why a 1-dimensional one - wouldn't you lose information about which parts belong to which sample/event/...?

I'm a bit uncertain what you mean/maybe I confused you. I meant that in our case, we want any representation which allows us to do batch[i] and get back what corresponds to a single input. How this is represented under the hood, the user shouldn't have to worry about (but the implementer of the transformation might want to specialize on different underlying representations, e.g. if it's represented as a higher-dim array then one might broadcast, etc.).
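As a rough sketch of that idea (the names `transform_batch` and `columns` are hypothetical, not part of Bijectors or any other package), generic code can rely on nothing but `batch[i]`, while the storage behind the batch varies:

```julia
# Hypothetical helper: apply `f` to every element of a batch. It only uses
# `eachindex(batch)` and `batch[i]`, not the underlying storage.
transform_batch(f, batch::AbstractVector) = [f(batch[i]) for i in eachindex(batch)]

# Hypothetical helper: present the columns of a matrix as a vector of views,
# so a matrix-backed batch goes through the same generic code path.
columns(X::AbstractMatrix) = [view(X, :, j) for j in axes(X, 2)]

X = [1.0 2.0 3.0; 4.0 5.0 6.0]                   # 2 features, 3 samples
batch_a = [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]   # plain vector of vectors
batch_b = columns(X)                             # views into the matrix

f(x) = sum(abs2, x)   # stand-in for a transformation of a single input

@assert transform_batch(f, batch_a) == transform_batch(f, batch_b) == [17.0, 29.0, 45.0]
```

An implementer who knows the batch is matrix-backed could add a specialised method that broadcasts over the whole array instead of looping, without changing what callers see.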

cscherrer commented 2 years ago

Thanks @oschulz for the ping :)

If there's refactoring in the works, here's my wishlist:

I've also been thinking of using StrideArrays for intermediate representations, since allocations can have a lot of overhead. This is still in the pondering stage, and I guess for a lightweight dependency it's a bit much. But maybe there can be an argument giving the array type? Most important IMO is not to disallow it by design.

oschulz commented 2 years ago

If there's refactoring in the works, here's my wishlist

That I'll happily leave to the Bijectors team, I'll be busy getting VariateTransformation registration-ready (finally) :-) Initially I'll just do something like tpapp/TransformVariables.jl#85 so that code can use all kinds of transformations without depending on specific trafo packages directly.

oschulz commented 2 years ago

Why a 1-dimensional one - wouldn't you lose information about which parts belong to which sample/event/...?

I'm a bit uncertain what you mean/maybe I confused you. I meant that in our case, we want any representation which allows us to do batch[i] and get back what corresponds to a single input.

Ah, now I get it - yes, of course! So there, batch could be an ArraysOfArrays.VectorOfSimilarVectors{<:Real} or a plain Vector{Vector{<:Real}} and so on, right?

but the implementer of the transformation might want to specialize on different underlying representations, e.g. if it's represented as a higher-dim array then one might broadcast

Yes, that's what I do in BAT.jl at the moment, it uses ArraysOfArrays extensively (ArraysOfArrays was partially designed exactly for the batch-of-samples use case).
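One detail worth keeping in mind when turning the batch types mentioned above into method signatures: Julia's parametric types are invariant, so `Vector{Vector{<:Real}}` read literally does not match a `Vector{Vector{Float64}}`. A minimal sketch, where `nsamples` is a hypothetical function used only for illustration:

```julia
# Invariance: the element type must match exactly unless it is left open with `<:`.
@assert !(Vector{Vector{Float64}} <: Vector{Vector{<:Real}})
@assert Vector{Vector{Float64}} <: Vector{<:Vector{<:Real}}
@assert Vector{Vector{Float64}} <: AbstractVector{<:AbstractVector{<:Real}}

# Hypothetical function accepting any vector-of-real-vectors batch.
nsamples(batch::AbstractVector{<:AbstractVector{<:Real}}) = length(batch)

@assert nsamples([[1.0, 2.0], [3.0, 4.0]]) == 2            # vector of vectors
M = [1 2; 3 4]
@assert nsamples([view(M, :, j) for j in axes(M, 2)]) == 2  # vector of column views
```

A signature of the `AbstractVector{<:AbstractVector{<:Real}}` form should also admit wrapper types that subtype `AbstractVector` with vector elements, though that is an assumption about those types rather than something checked here.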

devmotion commented 2 years ago

So there, batch could be an ArraysOfArrays.VectorOfSimilarVectors{<:Real} or a plain Vector{Vector{<:Real}} and so on, right?

Or even more generally just an AbstractVector collection with arbitrary possibly non-Array elements - at least that's supported by the KernelFunctions API where ColVecs and RowVecs (optimizations for the vector of vector case with data as matrices) are used currently (I'm still looking forward to replacing them with EachCol and EachRow: https://github.com/JuliaLang/julia/pull/32310).

The Bijectors refactoring is discussed in https://github.com/TuringLang/Bijectors.jl/discussions/178 and worked on in https://github.com/TuringLang/Bijectors.jl/pull/183.

torfjelde commented 2 years ago

Or even more generally just an AbstractVector collection with arbitrary possibly non-Array elements

Exactly. We want to support arbitrary inputs.

@cscherrer I'll make a separate issue from your comment since it's very useful feedback, but here we're talking a bit more specifically about adoption of InverseFunctions.jl and ChangesOfVariables.jl rather than a general rewrite (as is being worked on in #183 as mentioned by David).

cscherrer commented 2 years ago

Thanks @torfjelde. We've discussed #183, but the relationship of this issue (#199) to that wasn't really clear to me. New issue sounds good :)

oschulz commented 2 years ago

Or even more generally just an AbstractVector collection with arbitrary possibly non-Array elements Exactly. We want to support arbitrary inputs.

ValueShapes has that - it allows you to view a flat matrix of real numbers as a vector of NamedTuple, for example (with an ArrayOfArrays in the middle).

devmotion commented 2 years ago

This is just one particular example - another example would be e.g. computing the kernel matrix on a vector of graphs. If the API just expects that collections of inputs are provided as an AbstractVector, the elements of the AbstractVector can be anything and it's not required that the underlying storage is an array itself.
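A minimal sketch of that point, with entirely hypothetical names (`TinyGraph`, `edgecount_kernel`, `pairwise_kernel`; none of these come from KernelFunctions): the generic function only assumes an `AbstractVector` of inputs, and the elements here are not arrays at all.

```julia
# Hypothetical non-array input type.
struct TinyGraph
    nvertices::Int
    edges::Vector{Tuple{Int,Int}}
end

# Hypothetical toy kernel: compares two graphs by their edge counts.
edgecount_kernel(g1::TinyGraph, g2::TinyGraph) =
    exp(-abs(length(g1.edges) - length(g2.edges)))

# Generic over any AbstractVector of inputs; it never inspects the elements.
pairwise_kernel(k, xs::AbstractVector) = [k(x, y) for x in xs, y in xs]

graphs = [
    TinyGraph(3, [(1, 2), (2, 3)]),
    TinyGraph(3, [(1, 2)]),
    TinyGraph(2, Tuple{Int,Int}[]),
]

K = pairwise_kernel(edgecount_kernel, graphs)

@assert size(K) == (3, 3)
@assert K[1, 1] == 1.0                  # identical edge counts
@assert K[1, 2] == K[2, 1] == exp(-1)   # edge counts differ by one
```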

oschulz commented 2 years ago

the elements of the AbstractVector can be anything and it's not required that the underlying storage is an array itself.

Oh sure - I didn't mean it should be limited in any way!

oschulz commented 2 years ago

@oschulz: I could prepare a small non-breaking PR to add initial support for the InverseFunctions.jl

Not so small after all, but here it is: #212

oschulz commented 2 years ago

@willtebbutt with ChangesOfVariables and InverseFunctions support in LogExpFunctions and Bijectors now, ParameterHandling may be able to drop the explicit dependency on Bijectors.

devmotion commented 2 years ago

I removed it already a while ago: https://github.com/invenia/ParameterHandling.jl/pull/42

oschulz commented 2 years ago

I removed it already a while ago: invenia/ParameterHandling.jl#42

Oh right, sorry, should have checked first - I just had this stuck as a "to do after" in my mind. Great!