TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

Broadcasting? #145

Open cscherrer opened 3 years ago

cscherrer commented 3 years ago

There's an approach I've been thinking about for MeasureTheory.jl, and it could be nice to have it in Bijectors and avoid the type piracy :)

Say you have

julia> dist = MvNormal(zeros(2), ones(2))
DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)

julia> f = PlanarLayer(2)
PlanarLayer{Array{Float64,1},Array{Float64,1}}([0.8187830879660829, -1.280004857469378], [-0.12761939163655306, -0.7795079813636753], [-0.025642876005072607])

As a function, f takes arguments in ℝ²:

julia> f(zeros(2))
2-element Array{Float64,1}:
 0.009244220896803523
 0.010647769052268647
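To make the printed numbers less opaque, here is a minimal plain-Julia sketch (no Bijectors dependency) of the planar-flow map f(z) = z + û tanh(wᵀz + b) from Rezende and Mohamed (2015). It assumes that the three printed vectors are the fields (w, u, b) in that order and that û is the usual invertibility-preserving reparametrisation of u. Under those assumptions it should reproduce the f(zeros(2)) output above, but it is an illustration, not Bijectors' own implementation. The helper names `u_hat` and `planar` are hypothetical.

```julia
using LinearAlgebra: dot

# Values copied from the PlanarLayer printout above; assumed field order (w, u, b).
w = [0.8187830879660829, -1.280004857469378]
u = [-0.12761939163655306, -0.7795079813636753]
b = -0.025642876005072607

# Hypothetical helper: reparametrise u so that wᵀû = m(wᵀu) > -1,
# which keeps the planar map invertible.
# m(x) = -1 + softplus(x), with softplus(x) = log(1 + exp(x)).
function u_hat(u, w)
    wu = dot(w, u)
    m = -1 + log1p(exp(wu))
    return u .+ ((m - wu) / sum(abs2, w)) .* w
end

# Hypothetical planar flow: f(z) = z + û * tanh(wᵀz + b).
planar(z, w, u, b) = z .+ u_hat(u, w) .* tanh(dot(w, z) + b)

planar(zeros(2), w, u, b)  # ≈ [0.009244, 0.010648]
```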

To get the pushforward, we need to do

julia> transformed(dist, f)
Bijectors.TransformedDistribution{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},PlanarLayer{Array{Float64,1},Array{Float64,1}},Multivariate}(
dist: DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)

transform: PlanarLayer{Array{Float64,1},Array{Float64,1}}([0.8187830879660829, -1.280004857469378], [-0.12761939163655306, -0.7795079813636753], [-0.025642876005072607])
)

But conceptually, a distribution or measure is a kind of container (since it's a monad). This makes me wonder about using broadcasting for this. In that case we could just write

f.(dist)

I'm not sure whether there would be unintended consequences of doing it this way, but it wouldn't be the first case of broadcasting off the beaten path; see https://github.com/JuliaApproximation/QuasiArrays.jl
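One way the f.(dist) idea could be wired up, sketched with hypothetical toy types (`ToyNormal`, `Pushforward`, `draw`) rather than the real Bijectors/Distributions ones: a dot-call `f.(d)` lowers to `Base.materialize(Base.broadcasted(f, d))`, and `materialize` returns any non-`Broadcasted` object unchanged. So a `broadcasted` method that returns a lazy pushforward bypasses the rest of the broadcast machinery.

```julia
# Hypothetical toy types, not Bijectors or Distributions.
struct ToyNormal
    μ::Float64
    σ::Float64
end

# Lazy representation of the pushforward of `base` through `f`.
struct Pushforward{F,D}
    f::F
    base::D
end

# `f.(d)` lowers to `Base.materialize(Base.broadcasted(f, d))`; returning a
# non-`Broadcasted` object here means `materialize` hands it back unchanged.
Base.Broadcast.broadcasted(f, d::ToyNormal) = Pushforward(f, d)

# Hypothetical sampling helpers: draw from the base measure, then apply `f`.
draw(d::ToyNormal) = d.μ + d.σ * randn()
draw(p::Pushforward) = p.f(draw(p.base))

logn = exp.(ToyNormal(0.0, 1.0))   # a Pushforward, not an array
draw(logn)                         # one (positive) log-normal draw
```

Only the single-argument pattern `f.(d)` is covered here; every other argument pattern would need its own `broadcasted` method.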

Another option would be to just add a method

(f::Bijector)(dist::Distribution) = transformed(dist, f)

But if it works, the broadcasting approach seems more natural, from a mathematical perspective. What do you think?

devmotion commented 3 years ago

IMO the broadcasting machinery is quite heavy and all sorts of tricks and hacks are used in AD backends to (try to) deal with its complexity, so I try to avoid broadcasting whenever possible (and, e.g., just use map if dimensions don't have to be broadcasted, which in my experience often leads to simpler machine code as well). So personally I don't think there's a compelling reason for using broadcasting instead of transformed.
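A small illustration of the distinction drawn here: when the arguments already have the same shape, `map` and a dot-call compute the same result, and broadcasting only earns its extra machinery when singleton dimensions actually need to be expanded.

```julia
xs = [1.0, 2.0, 3.0]
ys = [10.0, 20.0, 30.0]

# Same shape: map and broadcasting agree, and map avoids the broadcast machinery.
same_map = map(+, xs, ys)
same_bc  = xs .+ ys

# Different shapes: broadcasting expands the singleton dimensions of a column
# vector and a 1×2 row into a 2×2 result; map never expands dimensions.
col = [1, 2]
row = [10 20]
outer = col .+ row   # [11 21; 12 22]
```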

Regarding the second suggestion, I have two main concerns. IMO in general there should be exactly one supported way of doing things (to achieve a consistent API and to avoid confusion), so I think one should use either the function syntax or transformed, but not both. A general problem with the function syntax is that it might become confusing if bijectors can be applied to completely different objects with completely different meanings.
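A minimal sketch of that ambiguity, with hypothetical toy types (`ToyBijector`, `ToyShift`, `ToyDist`, `ToyTransformed`) standing in for Bijector, Distribution and TransformedDistribution: once the proposed method exists, the same call syntax `b(...)` means either "evaluate at a point" or "push forward a distribution", selected only by the argument's type.

```julia
# Hypothetical stand-ins; these are not Bijectors or Distributions types.
abstract type ToyBijector end

struct ToyShift <: ToyBijector
    a::Float64
end

struct ToyDist
    μ::Float64
end

# Lazy "dist pushed forward through transform", mimicking only the shape of
# `transformed(dist, f)`, not its functionality.
struct ToyTransformed{D,B}
    dist::D
    transform::B
end

# Meaning 1: apply the bijector to a point.
(b::ToyShift)(x::Real) = x + b.a

# Meaning 2 (the proposed sugar): "apply" the bijector to a distribution.
(b::ToyBijector)(d::ToyDist) = ToyTransformed(d, b)
```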

cscherrer commented 3 years ago

Thanks @devmotion. The first point makes sense to me. I disagree on the second: syntactic sugar is generally a good thing, ideally using the same machinery under the hood.

The use case I have in mind: say you have (pseudo-PPL)

z ~ TDist(ν)
x = μ + σ * z

But maybe you want x itself as a distribution. You could do

x ~ transformed(TDist(ν), Shift(μ) ∘ Scale(σ))
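For reference, the density of this pushforward follows from the change-of-variables formula. For the affine map x = μ + σz with z ~ p_Z and σ ≠ 0 (here p_Z would be the TDist(ν) density):

```latex
x = \mu + \sigma z, \qquad z = \frac{x - \mu}{\sigma}, \qquad
p_X(x) = p_Z\!\left(\frac{x - \mu}{\sigma}\right)\frac{1}{|\sigma|},
\qquad
\log p_X(x) = \log p_Z\!\left(\frac{x - \mu}{\sigma}\right) - \log|\sigma|.
```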

But to me this is much cleaner:

x ~ μ .+ σ .* TDist(ν)

Because TDist(ν) is a measure, the undotted form would mean something else: σ * TDist(ν) would just be the scaled measure.
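For this spelling specifically, Julia fuses the dots into nested calls, `Base.materialize(Base.broadcasted(+, μ, Base.broadcasted(*, σ, d)))`, so a sketch needs one `broadcasted` method per (operation, argument type) pair. The types below (`StdT`, `LocScale`) are hypothetical stand-ins that only record the affine map μ + σ·Z; they are not MeasureTheory or Bijectors types.

```julia
# Hypothetical toy types, not Bijectors, Distributions or MeasureTheory.
struct StdT
    ν::Float64          # degrees of freedom
end

# The law of μ + σ * Z with Z ~ base.
struct LocScale{D}
    base::D
    μ::Float64
    σ::Float64
end

# `σ .* d` lowers to `broadcasted(*, σ, d)`; return a LocScale instead.
Base.Broadcast.broadcasted(::typeof(*), σ::Real, d::StdT) =
    LocScale(d, 0.0, Float64(σ))
Base.Broadcast.broadcasted(::typeof(*), σ::Real, d::LocScale) =
    LocScale(d.base, Float64(σ * d.μ), Float64(σ * d.σ))

# `μ .+ d` lowers to `broadcasted(+, μ, d)`.
Base.Broadcast.broadcasted(::typeof(+), μ::Real, d::StdT) =
    LocScale(d, Float64(μ), 1.0)
Base.Broadcast.broadcasted(::typeof(+), μ::Real, d::LocScale) =
    LocScale(d.base, Float64(μ + d.μ), d.σ)

x = 1.0 .+ 2.0 .* StdT(3.0)   # base = StdT(3.0), μ = 1.0, σ = 2.0
```

Only scalar-on-the-left patterns are covered; `d .* σ` or `d .- μ` would each need further methods.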

I think part of what feels weird to me is that in transformed(dist, f) the function-like thing is not in the first slot.

Another (maybe better than broadcasting) approach would be

x ~ transform(TDist(ν)) do x
        μ + σ * x
    end

The implementation could apply the function to an Identity bijector, maybe similar to how Measurements.jl overloads arithmetic for its own number type. I'm assuming this would either only work in special cases, or maybe only yield a Bijector in special cases.
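A small aside on why argument order matters for this proposal: Julia's do syntax always passes the anonymous function as the first argument, so `transform(TDist(ν)) do x ... end` means `transform(x -> μ + σ * x, TDist(ν))`, and it only works with function-first signatures. The helper `apply_to` below is hypothetical, just to show the desugaring.

```julia
# Hypothetical function-first helper, only here to show the desugaring.
apply_to(f, x) = f(x)

# The do-block becomes an anonymous function passed as the FIRST argument...
r1 = apply_to(3) do x
    2x + 1
end

# ...so the call above is the same as writing the function out explicitly.
r2 = apply_to(x -> 2x + 1, 3)
```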

Anyway, I guess I could do this outside of Bijectors, but it's still helpful talking it through :)

torfjelde commented 3 years ago

Though I appreciate the reasoning behind the first suggestion, I don't think we want to make an interface that makes sense only for people familiar with monads :) It didn't even make sense to me before you explained it.

For the second suggestion, I'm not immediately opposed, but I also somewhat agree with what @devmotion is saying. I also think it's useful to remember that for PPLs a lot of the users will be completely new, even to the programming language being used. As a result, keeping the outward-facing interface simple is particularly important.

Also, I'm not sure the following is even preferable

(Shift(μ) ∘ Scale(σ))(dist)

to

transformed(dist, Shift(μ) ∘ Scale(σ))

I do agree with your point about the function not being the first argument. We also considered pushforward, but this might be a bit unfriendly to people not familiar with basic measure theory. The reasoning is/was:

Unfortunately we went with the first one. Though I definitely prefer the second one for my personal use, I do think the transform is better for the simplicity reasons described above.

It's also a bit unclear to me what you're suggesting with

x ~ transform(TDist(ν)) do x
        μ + σ * x
    end

This would require you to examine the expression inside the do block to figure out which bijectors are used?

cscherrer commented 3 years ago

> Though I appreciate the reasoning behind the first suggestion, I don't think we want to make an interface that makes sense only for people familiar with monads :) It didn't even make sense to me before you explained it.

I wouldn't lead with the monad thing in presenting it to a general audience, that's just a more technical justification. More intuitive is that a distribution or measure is a kind of collection.

> It's also a bit unclear to me what you're suggesting with
>
> x ~ transform(TDist(ν)) do x
>         μ + σ * x
>     end
>
> This would require you to examine the expression inside the do block to figure out which bijectors are used?

Sure, but that's pretty easy. Start with an Identity bijector, and have

transform(f, dist) = transformed(dist, f(Identity()))

Then we just need some methods like

Base.:+(x::Real, f::Bijector) = Shift(x) ∘ f
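A self-contained sketch of this tracing idea, using hypothetical toy types (`IdBij`, `ShiftBij`, `ScaleBij`, `ComposedBij`, `trace_bijector`) in place of Bijectors' real Identity, Shift and Scale, and covering only `+` and `*` with a real scalar. Feeding the identity through an ordinary Julia function makes the overloaded arithmetic build up a composed bijector, which is what `f(Identity())` in the transform definition above relies on. Anything outside the overloaded operations (e.g. `exp(x)`) would hit a MethodError, in line with the "only in special cases" caveat.

```julia
# Hypothetical toy bijectors, standing in for Bijectors' Identity, Shift and Scale.
abstract type ToyBij end
struct IdBij <: ToyBij end
struct ShiftBij <: ToyBij
    a::Float64
end
struct ScaleBij <: ToyBij
    s::Float64
end
struct ComposedBij{O<:ToyBij,I<:ToyBij} <: ToyBij
    outer::O
    inner::I
end

# Evaluation at a point.
(::IdBij)(x::Real) = x
(b::ShiftBij)(x::Real) = x + b.a
(b::ScaleBij)(x::Real) = b.s * x
(c::ComposedBij)(x::Real) = c.outer(c.inner(x))

# (outer ∘ inner)(x) == outer(inner(x)), as for ordinary functions.
Base.:∘(outer::ToyBij, inner::ToyBij) = ComposedBij(outer, inner)

# Arithmetic with a real scalar records a new bijector instead of a number.
Base.:+(a::Real, f::ToyBij) = ShiftBij(a) ∘ f
Base.:+(f::ToyBij, a::Real) = ShiftBij(a) ∘ f
Base.:*(s::Real, f::ToyBij) = ScaleBij(s) ∘ f
Base.:*(f::ToyBij, s::Real) = ScaleBij(s) ∘ f

# Trace a function by applying it to the identity bijector.
trace_bijector(f) = f(IdBij())

μ, σ = 1.0, 2.0
b = trace_bijector(x -> μ + σ * x)   # ShiftBij(1.0) ∘ (ScaleBij(2.0) ∘ IdBij())
b(3.0)                               # 7.0
```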