TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

Reparameterization attached to `Distribution` #4

Open torfjelde opened 4 years ago

torfjelde commented 4 years ago

Overview

Since distributions have to be re-implemented here and the focus is on AD, I was wondering if it would be of any interest to add reparameterization to Distribution. In an AD context you usually want to work in ℝ (unconstrained) rather than in constrained space, e.g. when optimizing the parameters of a Distribution.

A simple example is Normal(μ, σ). One might want to perform a maximum likelihood estimate (MLE) of μ and σ by gradient descent (GD). This requires differentiating the logpdf w.r.t. μ and σ and then updating the parameters of the Normal accordingly. But for the distribution to be valid we simultaneously need to ensure that σ > 0. Usually we accomplish this by instead differentiating the function

```julia
(μ, logσ) -> logpdf(Normal(μ, exp(logσ)), x)
# instead of
(μ, σ) -> logpdf(Normal(μ, σ), x)
```
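To make the chain rule behind this concrete, here is a dependency-free sketch (not DistributionsAD code; `normal_logpdf`, `grad_unconstrained` and `mle_normal` are hypothetical names). It writes the Normal log-density in terms of (μ, logσ), derives its gradient by hand, and runs plain gradient ascent on the log-likelihood (equivalently, descent on its negative). Because logσ may take any real value, no step can ever produce σ ≤ 0.

```julia
# Log-density of Normal(μ, exp(logσ)) at x, written out by hand.
normal_logpdf(μ, logσ, x) = -((x - μ) / exp(logσ))^2 / 2 - log(2π) / 2 - logσ

# Hand-derived gradient w.r.t. the unconstrained parameters (μ, logσ).
function grad_unconstrained(μ, logσ, x)
    σ = exp(logσ)
    z = (x - μ) / σ
    # ∂/∂μ = z/σ; ∂/∂logσ = z^2 - 1 (chain rule through σ = exp(logσ))
    return (z / σ, z^2 - 1)
end

# Gradient ascent on the average log-likelihood; σ = exp(logσ) > 0 at every step.
function mle_normal(xs; steps = 2000, lr = 0.1)
    μ, logσ = 0.0, 0.0
    for _ in 1:steps
        gμ, gs = 0.0, 0.0
        for x in xs
            a, b = grad_unconstrained(μ, logσ, x)
            gμ += a
            gs += b
        end
        μ += lr * gμ / length(xs)
        logσ += lr * gs / length(xs)
    end
    return μ, exp(logσ)
end
```

For the data `[1.0, 2.0, 3.0, 4.0]` this converges to the sample mean 2.5 and the MLE scale `sqrt(1.25)`.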

The proposal is to also allow something like

```julia
reparam(μ, σ) = μ, exp(σ)
Normal(μ, logσ, reparam)
```

which in the MLE case allows us to differentiate

```julia
(μ, σ) -> logpdf(Normal(μ, σ, reparam), x)
```

Why?

As you can see, in the case of a univariate Normal this doesn't offer much advantage over the current approach. But the current approach is a special case of what we can then support (by letting reparam equal identity), and I believe there certainly are cases where this is very useful:

```julia
for i = 1:num_steps
    nn = Dense(W, b)
    d = MvNormal(μ, exp.(logσ))  # diagonal MvNormal; exp keeps the scale positive

    x = rand(d)
    y = nn(x)

    # Do computation using `y`
    # ...

    Tracker.back!(...)

    update!(W, b)
    update!(μ, logσ)
end
```

VS.

```julia
Flux.@treelike MvNormal

nn = Dense(param(W_init), param(b_init))
# proposed constructor: σ_init is stored unconstrained, the function maps it to a valid scale
d = MvNormal(param(μ_init), param(σ_init), (μ, σ) -> (μ, exp.(σ)))

for i = 1:num_steps
    x = rand(d)
    y = nn(x)

    # Do computation using `y`
    # ...

    Tracker.back!(...)

    update!(Flux.params(nn))  # more general
    update!(Flux.params(d))
end
```
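For comparison, a dependency-free sketch of the second style (the type and function names are hypothetical; this is not Flux or Distributions API). The unconstrained parameters live inside the distribution object, the constraint σ = exp.(logσ) is applied inside each method, and a gradient step mutates the stored arrays in place, so the same object `d` is reused across iterations.

```julia
using Random

# Hypothetical diagonal Gaussian that stores unconstrained parameters (μ, logσ).
struct DiagNormalAD{V<:AbstractVector{<:Real}}
    μ::V
    logσ::V
end

# Reparameterized sampling: x = μ + σ .* ε, with σ = exp.(logσ) computed on the fly.
sample_diag(rng::AbstractRNG, d::DiagNormalAD) = d.μ .+ exp.(d.logσ) .* randn(rng, length(d.μ))

# Log-density; log(σ) is just logσ, so no log(exp(⋅)) round trip is needed.
function logpdf_diag(d::DiagNormalAD, x::AbstractVector)
    z = (x .- d.μ) ./ exp.(d.logσ)
    return -sum(abs2, z) / 2 - length(x) * log(2π) / 2 - sum(d.logσ)
end

# Collect the mutable parameter arrays, loosely in the spirit of Flux.params.
params_of(d::DiagNormalAD) = (d.μ, d.logσ)

# In-place gradient-ascent step on every parameter array; `d` itself is never rebuilt.
function step!(d::DiagNormalAD, grads; lr = 0.1)
    for (p, g) in zip(params_of(d), grads)
        p .+= lr .* g
    end
    return d
end
```

The gradients passed to `step!` would come from whichever AD backend is in use; here they are supplied by hand.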


# Example implementation
```julia
using Distributions, StatsFuns, Random

# P is the reparameterization, stored as a type parameter
abstract type ParameterizedDistribution{F, S, P} <: Distribution{F, S} end

# maybe?
transformation(::ParameterizedDistribution{F, S, P}) where {F, S, P} = P

struct NormalAD{T<:Real, P} <: ParameterizedDistribution{Univariate, Continuous, P}
    μ::T
    σ::T
end

NormalAD(μ::T, σ::T) where {T<:Real} = NormalAD{T, identity}(μ, σ)
NormalAD(μ::T, σ::T, f::Function) where {T<:Real} = NormalAD{T, f}(μ, σ)

# convenience; probably don't want to do this in an actual implementation
Base.identity(args...) = identity.(args)

function Distributions.logpdf(d::NormalAD{<:Real, P}, x::Real) where {P}
    μ, σ = P(d.μ, d.σ)
    z = (x - μ) / σ
    return -(z^2 + log2π) / 2 - log(σ)
end

function Distributions.rand(rng::AbstractRNG, d::NormalAD{T, P}) where {T, P}
    μ, σ = P(d.μ, d.σ)
    return μ + σ * randn(rng)
end
```

```julia
julia> # Standard: μ ∈ ℝ, σ ∈ ℝ⁺
       d1 = NormalAD(0.0, 1.0)
NormalAD{Float64,identity}(μ=0.0, σ=1.0)

julia> d2 = Normal(0.0, 1.0)
Normal{Float64}(μ=0.0, σ=1.0)

julia> x = randn()
-0.028232023381049923

julia> logpdf(d1, x) == logpdf(d2, x)
true

julia> # Real-valued: μ ∈ ℝ, σ ∈ ℝ using `exp`
       d3 = NormalAD(0.0, 0.0, (μ, σ) -> (μ, exp(σ)))
NormalAD{Float64,getfield(Main, Symbol("##3#4"))()}(μ=0.0, σ=0.0)

julia> logpdf(d3, x) == logpdf(d2, x)
true

julia> #  Real-valued: μ ∈ ℝ, σ ∈ ℝ using `softplus`
       d4 = NormalAD(0.0, invsoftplus(1.0), (μ, σ) -> (μ, softplus(σ)))
NormalAD{Float64,getfield(Main, Symbol("##9#10"))()}(μ=0.0, σ=0.541324854612918)

julia> logpdf(d4, x) == logpdf(d2, x)
true
```

Together with Tracker.jl

```julia
julia> using Tracker

julia> μ = param(0.0)
0.0 (tracked)

julia> σ = param(0.0)
0.0 (tracked)

julia> d_tracked = NormalAD(μ, σ, (μ, σ) -> (μ, exp(σ)))
NormalAD{Tracker.TrackedReal{Float64},getfield(Main, Symbol("##5#6"))()}(μ=0.0 (tracked), σ=0.0 (tracked))

julia> lp = logpdf(d_tracked, x)
-0.9193370567767668 (tracked)

julia> Tracker.back!(lp)

julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(-0.028232023381049923, -0.9992029528558118)

julia> x = rand(d_tracked)
-1.6719800201542028 (tracked)

julia> Tracker.back!(x)

julia> Tracker.grad.((d_tracked.μ, d_tracked.σ))
(0.9717679766189501, -2.6711829730100147)
```
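The printed gradients can be checked by hand. With σ = exp(logσ) and z = (x − μ)/σ, the partial derivatives below give (x, x² − 1) ≈ (−0.02823, −0.99920) at μ = 0, logσ = 0, matching the first pair. For `rand`, the derivatives are 1 and x − μ ≈ −1.67198. The second printed pair equals the sum of both contributions, (1 − 0.02823, −1.67198 − 0.99920), which indicates that the gradients from the two `Tracker.back!` calls were accumulated rather than reset in between.

```math
\log p(x \mid \mu, \log\sigma) = -\tfrac{1}{2} z^2 - \tfrac{1}{2}\log 2\pi - \log\sigma,
\qquad z = \frac{x - \mu}{e^{\log\sigma}}

\frac{\partial \log p}{\partial \mu} = \frac{z}{\sigma},
\qquad
\frac{\partial \log p}{\partial \log\sigma} = z^2 - 1

x = \mu + e^{\log\sigma}\,\varepsilon
\;\Rightarrow\;
\frac{\partial x}{\partial \mu} = 1,
\qquad
\frac{\partial x}{\partial \log\sigma} = \sigma\varepsilon = x - \mu
```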

Alternative approach: wrap Distribution

An alternative approach is to do something similar to TransformedDistribution in Bijectors.jl, where you simply wrap a Distribution instance. Then you could require the user to provide a reparam method which takes what's returned from Distributions.params(d::Distribution) and applies the reparameterization correctly.

This requires significantly less work, but IMO it isn't as nice, nor as easy to extend and work with.
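A minimal sketch of this wrapping alternative, assuming a hypothetical `Reparameterized` type and a dependency-free `ToyNormal` stand-in (with Distributions.jl the raw tuple would instead line up with `Distributions.params(d)`). The wrapper stores unconstrained values plus the user's `reparam` function, and rebuilds the underlying distribution whenever a method needs it.

```julia
# Stand-in for Distributions.Normal, including a positivity check on σ.
struct ToyNormal{T<:Real}
    μ::T
    σ::T
    function ToyNormal(μ::T, σ::T) where {T<:Real}
        σ > 0 || throw(ArgumentError("σ must be positive"))
        return new{T}(μ, σ)
    end
end

toy_logpdf(d::ToyNormal, x::Real) = -((x - d.μ) / d.σ)^2 / 2 - log(2π) / 2 - log(d.σ)

# Hypothetical wrapper: D is the distribution type, `raw` holds unconstrained
# parameters, `reparam` maps them to valid constructor arguments.
struct Reparameterized{D, P<:Tuple, F}
    raw::P
    reparam::F
end
Reparameterized{D}(raw::Tuple, reparam) where {D} =
    Reparameterized{D, typeof(raw), typeof(reparam)}(raw, reparam)

# Rebuild the constrained distribution on demand.
materialize(r::Reparameterized{D}) where {D} = D(r.reparam(r.raw...)...)

# Every method of the wrapped type forwards through `materialize`.
toy_logpdf(r::Reparameterized, x::Real) = toy_logpdf(materialize(r), x)
```

A raw scale of `-1.0` is perfectly valid here, since only `exp(-1.0)` ever reaches the `ToyNormal` constructor. The cost is that every method must be forwarded, which is the extra boilerplate the comment alludes to.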

willtebbutt commented 4 years ago

I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions? I had imagined something along the lines of an interface like

```julia
a_positive, a_unconstrained = positive(inv_link_or_link_or_whatever, a_positive_init)
```
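One way such an orthogonal helper could look, as a sketch only: the name `positive` comes from willtebbutt's comment, but this three-argument signature (separate `link` and `invlink`) and the `Ref`-based storage are assumptions made for illustration. The unconstrained value is what an optimiser updates; the getter always maps it back to a positive number, and nothing here knows about distributions.

```julia
# Hypothetical helper. `link` maps (0, ∞) → ℝ and `invlink` maps ℝ → (0, ∞).
# Returns a zero-argument getter for the positive value and the unconstrained storage.
function positive(link, invlink, x_init::Real)
    x_init > 0 || throw(ArgumentError("initial value must be positive"))
    raw = Ref(link(x_init))              # unconstrained parameter; any real value is valid
    get_positive = () -> invlink(raw[])  # always maps back into (0, ∞)
    return get_positive, raw
end
```

With StatsFuns, the same call could use `invsoftplus`/`softplus` instead of `log`/`exp`.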

Then we're just talking about the generic parameter handling / transformation problem, rather than anything inherently probabilistic.

Also, could we please try to think about how this plays with Zygote, rather than Tracker, as Tracker's days are numbered?

willtebbutt commented 4 years ago

Oops, didn't mean to close

torfjelde commented 4 years ago

> I don't really understand why this has to be tied to a distributions library. Wouldn't it be more straightforward / useful to have this as an orthogonal thing that just plays nicely with distributions?

I see what you're saying, but I think it's just too closely related. And I think it's not far-fetched to say that "reparameterization of a Distribution is related to Distributions.jl"? Also in some cases it can simplify certain computations, e.g. entropy for a DiagMvNormal using exp to enforce a positivity constraint on the variance. And my main motivation is that you end up performing the transformations "behind the scenes" rather than the user having to do this in every method that needs it. You do it right once in the implementation of the Distribution and then no more. And the standard case is simply an instance of the more general reparameterizable Distribution, so the user who doesn't care doesn't have to care. Other than more work, I think the only downside is that it's more difficult to perform checks as to whether or not the parameters are valid.

> Also, could we please try to think about how this plays with Zygote, rather than Tracker, as Tracker's days are numbered?

But I think Zygote also intends to support AD wrt. parameters of a struct, right? I can't find the issue right now, but I swear I saw @MikeInnes discussing something like this somewhere. If so, I think my argument using Tracker.jl still holds?

MikeInnes commented 4 years ago

I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.

mohamed82008 commented 4 years ago

A few comments I have.

  1. Doing constrained optimization by transforming the constrained variables is just one way of doing constrained optimization. There are optimization algorithms that can efficiently handle box constraints, semidefinite constraints, linear constraints, etc.
  2. I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. x -> logpdf(Normal(1.0, exp(x)), 1.0) is pretty efficient.
  3. However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky of the covariance or precision. I think these should be possible with multiple dispatch. Providing things like MvNormal(mu, Covariance(A)) or MvNormal(mu, Precision(A)). If A is a Cholesky we can also construct the PDMat directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x). I believe the distribution (re-)construction in this case should not allocate since we are not factorizing A.
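A dependency-free sketch of this dispatch idea using only the LinearAlgebra standard library; `Covariance`, `Precision` and `mvn_logpdf` are hypothetical stand-ins, not existing Distributions.jl or PDMats.jl API. Passing an existing `Cholesky` (for example one built from a triangular `L` via `Cholesky(L, 'L', 0)`) skips factorization entirely.

```julia
using LinearAlgebra

# Hypothetical marker types selecting how the matrix argument is interpreted.
struct Covariance{M}
    A::M
end
struct Precision{M}
    A::M
end

# Σ given by its Cholesky factor, Σ = L*L'.
function mvn_logpdf(μ::AbstractVector, C::Covariance{<:Cholesky}, x::AbstractVector)
    r = C.A.L \ (x - μ)   # r'r = (x-μ)' Σ⁻¹ (x-μ)
    return -(dot(r, r) + logdet(C.A) + length(x) * log(2π)) / 2
end

# Precision Λ = Σ⁻¹ given by its Cholesky factor, Λ = U'U.
function mvn_logpdf(μ::AbstractVector, P::Precision{<:Cholesky}, x::AbstractVector)
    r = P.A.U * (x - μ)   # r'r = (x-μ)' Λ (x-μ)
    return -(dot(r, r) - logdet(P.A) + length(x) * log(2π)) / 2  # logdet Σ = -logdet Λ
end

# Dense inputs are factorized once, then forwarded to the Cholesky methods.
mvn_logpdf(μ, C::Covariance{<:AbstractMatrix}, x) =
    mvn_logpdf(μ, Covariance(cholesky(Symmetric(C.A))), x)
mvn_logpdf(μ, P::Precision{<:AbstractMatrix}, x) =
    mvn_logpdf(μ, Precision(cholesky(Symmetric(P.A))), x)
```

Differentiating through `L -> mvn_logpdf(μ, Covariance(Cholesky(L, 'L', 0)), x)` then gives the triangular reparameterization with no factorization in the loop.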

Since we are discussing changes to Distributions, pinging @matbesancon.

torfjelde commented 4 years ago
> 2. I think doing the re-parameterization of the constrained parameters at the optimization/differentiation layer, not the distribution layer, is the better approach in many cases at no loss of efficiency, e.g. x -> logpdf(Normal(1.0, exp(x)), 1.0) is pretty efficient.

That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach to certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction).

It also doesn't solve the issue of "interoperability" with the parts of the ecosystem in which Distributions.jl is often used, e.g. with Tracker/Zygote. It of course works, but for larger models it can be quite a hassle compared to tying the parameters to the Distribution instance rather than keeping track of them through variables outside of the Distribution.

> 3. However, I also see the need for being able to construct a distribution using different parameters, e.g. precision vs covariance matrix, or directly using a triangular matrix which could be the Cholesky of the covariance or precision. I think these should be possible with multiple dispatch. Providing things like MvNormal(mu, Covariance(A)) or MvNormal(mu, Precision(A)). If A is a Cholesky we can also construct the PDMat directly. With these more efficient constructors, we get the triangular re-parameterization for free, e.g. L -> logpdf(MvNormal(mu, Covariance(Cholesky(L, 'L', 0))), x). I believe the distribution (re-)construction in this case should not allocate since we are not factorizing A.

I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?

willtebbutt commented 4 years ago

> That's true, but in multivariate cases you still cannot do in-place updates of the parameters (though to allow this you'd have to take a slightly different approach to certain distributions than what Distributions.jl is currently doing, e.g. MvNormal assumes the covariance matrix is constant, so the Cholesky decomposition is performed once upon construction)

I think this is one of the key aspects of this discussion. I'm personally more of a fan of the functional approach, but I appreciate that there are merits to both approaches. I'm not really sure which way the community is leaning here, perhaps @MikeInnes or @oxinabox can comment? If I remember correctly, Zygote's recommended mode of operation is now the functional style?

torfjelde commented 4 years ago

I started out preferring the more functional style, but have recently grown quite fond of the Flux approach. Granted, I've recently been using more neural networks where I think this approach is particularly useful.

Also, it's worth noting that with what's proposed here you can do both (which is why I like it!:) )

oxinabox commented 4 years ago

On an earlier point:

> I haven't followed this issue carefully but (1) yes, Zygote supports structs well and (2) it'd be nice not to have to load DistributionsAD on top of Distributions to get AD to work (not sure if that's the plan). Happy to look at support directly in Zygote, maybe via requires, if that's an option.

I had this discussion with @matbesancon in the context of ChainRules. My recollection is that, while he is happy about AD for derivatives, he absolutely does not want it in Distributions.jl. ChainRules.jl (not ChainRulesCore) is just adding @require for these cases. (Rewriting this for ChainRules is still a little way off; we need to continue improving struct support for that, I think.)

mohamed82008 commented 4 years ago

> I think MvNormal already does this, no? But what is the difference between this and the more general approach of allowing "lazy" transformations like what this issue is proposing? It seems, uhmm, maybe a bit arbitrary to allow reparameterizations, but only for Cholesky and Precision? I understand you could do this for more reparameterizations, e.g. define Normal(μ, Exp(σ)) and so on, but this will require even more work and be less flexible than what this issue is proposing, right?

@torfjelde You can still do lazy transformations by multiple dispatch, as you said, using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)), which internally stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) where possible. For example, logdet(Exp(Σ)) = tr(Σ).
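The lazy-wrapper idea can be sketched with the LinearAlgebra standard library alone. Here `MatExp` plays the role of the `Exp` wrapper from the comment (renamed, and paired with a hypothetical `invquad` helper, purely for illustration). It stores Σ, answers `logdet` via det(exp(Σ)) = exp(tr(Σ)), and computes the quadratic form through exp(Σ)⁻¹ = exp(−Σ).

```julia
using LinearAlgebra

# Hypothetical lazy wrapper standing for the matrix exponential exp(Σ).
struct MatExp{M<:AbstractMatrix}
    Σ::M
end

# det(exp(Σ)) = exp(tr(Σ)), so logdet needs no matrix exponential at all.
LinearAlgebra.logdet(E::MatExp) = tr(E.Σ)

# v' exp(Σ)⁻¹ v, using exp(Σ)⁻¹ = exp(-Σ) instead of a linear solve.
invquad(E::MatExp, v::AbstractVector) = dot(v, exp(-E.Σ) * v)

# Materialize the dense matrix only when it is really needed.
dense(E::MatExp) = exp(E.Σ)
```

Because dispatch picks the `logdet` method from the wrapper type, any code calling `logdet` on the covariance gets the O(n) trace instead of an exponential plus a factorization.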

Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.

If we are talking about modifying the distribution in-place (no AD), we can do that using the lazy function wrapper. Note that we always have to define the fields of the distribution according to its struct definition. So we have one of two scenarios:

  1. We tap into the innermost constructors, for PDMat and MvNormal for example, to define our distribution dist once while keeping a handle to Σ that we can modify in-place from outside, thereby affecting the next logpdf(dist, x) result.
  2. We call an outer constructor that copies, does linear algebra, or calls other functions that render our handles to Σ independent of the returned distribution struct.
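A toy illustration of the two scenarios, with hypothetical types and a diagonal scale vector standing in for Σ. A constructor that stores the caller's array keeps the handle alive, so an in-place edit changes the next log-density; a constructor that copies severs it.

```julia
# Scenario 1: the struct stores the caller's array itself.
struct KeepsHandle{V<:AbstractVector{<:Real}}
    σ::V
end

# Scenario 2: the constructor copies (as a factorizing outer constructor effectively does).
struct Copies{V<:AbstractVector{<:Real}}
    σ::V
    Copies(σ::AbstractVector{<:Real}) = (c = copy(σ); new{typeof(c)}(c))
end

# Log-density of a zero-mean diagonal Gaussian with scales σ, evaluated at x = 0.
logp_at_zero(d) = -length(d.σ) * log(2π) / 2 - sum(log, d.σ)
```

After `σ .= 2.0` on the caller's vector, only the `KeepsHandle` instance reports the new log-density.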

This is fundamentally a constructor definition problem. It is a question of how we can construct the distribution while enabling in-place modification of the inputs. Lazy function wrappers take us some of the way. Note that at the end of the day we still need to satisfy the field type signature of the distribution struct, so we may need to modify the type parameters of the distribution struct to accept more generic matrix types, like a lazy matrix-valued function which subtypes AbstractMatrix. Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making two versions of the same struct, one persistent and one lazy.

So in summary, the anonymous function and dispatch-based laziness approach enables us to:

  1. Think about ways to make various functions more efficient, e.g. logdet,
  2. Avoid the need for an AD version of every distribution,
  3. Keep handles to the inputs passed to the outer constructor if we get laziness right, which enables in-place modification.

Note that at this point, it is not a question of whether we need arbitrary re-parameterization, just the API choice. I am leaning towards not having a struct for every distribution for AD purposes only, using anonymous functions and dispatch-based laziness to gain any efficiency and/or flexibility benefits. Ironically, we already implement an AD distribution for MvNormal here to work around some Distributions-PDMats complexity. But for a long-term solution we should try to live within the boundaries of Distributions.jl and PDMats.jl.

Pinging @ChrisRackauckas in case he has opinions on this.

torfjelde commented 4 years ago

> @torfjelde You can still do lazy transformations by multiple dispatch, as you said, using Normal(μ, Exp(σ)) for example. For MvNormal, we can also do MvNormal(μ, Exp(Σ)), which internally stores a lazy wrapper of Σ and dispatches to efficient v' Exp(Σ)^-1 v and logdet(Exp(Σ)) where possible. For example, logdet(Exp(Σ)) = tr(Σ).

Yeah, I understood that, but it would still require always building an explicit type Exp which could do this, in contrast to the user just passing in the exp function and us wrapping every use of σ in it (this approach wouldn't work for every case, but in the univariate case it would be "one impl works for all transformations").

But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:) (You might have already realized this!) This is basically a "you know what you're doing"-type. Well, it's going to be rather annoying to have to specify different behavior for all combinations of the different parameters, e.g. if you want to apply log to μ and exp to σ you have to implement Normal{Log, Exp}, Normal{<:Real, Exp} and Normal{Log, <:Real} in addition to the existing implementation. Granted, the same issue is a problem in what I'm proposing if you require a separate transform for each parameter and you want specific behavior for exp on σ.

I think I'm coming around to your suggestion!:) It still seems like making this compatible with current Distribution is going to be, uhmm, slightly challenging.

> Dispatching on reparam in your proposal for efficient tricks like this is only possible if reparam itself uses the lazy Exp internally and we dispatch on Exp for logdet. So if we can avoid making our own AD types using the lazy wrapper approach directly, that would be better.

You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.

> Learning to live within those boundaries and pushing them where it makes sense to enable dispatch-based laziness seems like a more Julian approach to me than making two versions of the same struct, one persistent and one lazy.

"to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃

mohamed82008 commented 4 years ago

> You could still do this when P is, say, the actual function exp though, right? But maybe this has some issues I'm not fully aware of.

Well in your proposal IIUC, P is acting on all the arguments together not each one individually. So we don't really know that it is using exp on the covariance inside from its type only to do any magical specialization on P. This means we still need to rely on Exp for the dispatch-based lazy specialization of logdet for example.

> "to where it makes sense" -> "to where we can" seems like a more accurate statement 🙃

True, but if we hit a wall, we can decide to temporarily branch off until the obstacle is removed. This is what we do now for MvNormal and arguably with this whole package.

> But after reading your comment I realize we can just make a Lazy{exp}(σ) wrapper of σ and do the same thing as I wanted to do:)

Yes, this is a nice generic way of defining lazy wrappers. Exp can be an alias for Lazy{exp}.
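A minimal sketch of such a generic wrapper; `Lazy`, `value` and `logscale` are hypothetical names for illustration. `Lazy{f}(x)` stands for `f(x)` without evaluating it, `Exp` is the alias from the comment, and a specialized method picks up `Exp` by dispatch to skip a `log(exp(⋅))` round trip.

```julia
# Hypothetical lazy wrapper: Lazy{f}(x) stands for f(x) without evaluating it.
struct Lazy{F, T}
    x::T
end
Lazy{F}(x::T) where {F, T} = Lazy{F, T}(x)

# Evaluate the wrapped transformation on demand; plain numbers pass through.
value(l::Lazy{F}) where {F} = F(l.x)
value(x::Real) = x

const Exp = Lazy{exp}

# Generic fallback: works for any wrapper or for a plain number.
logscale(σ) = log(value(σ))
# Specialization chosen by dispatch: log(exp(s)) == s, so skip both calls.
logscale(σ::Exp) = σ.x
```

Besides saving work, the specialization is more accurate: `logscale(Exp(-800.0))` returns `-800.0`, whereas `log(exp(-800.0))` is `-Inf` because `exp` underflows to zero in `Float64`.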

torfjelde commented 4 years ago

I completely agree with the last comment:)

One thing though: this "wrapping"-approach means that if we want type-stability we'd have to allow all the parameters of a distribution to take on different types, e.g. Beta(a::T, b::T) can't be used since you might want to do Beta(a::Lazy{T, f1}, b::Lazy{T, f2}) where f1 and f2 are two different functions.

It still seems like the best approach, but it's worth noting that this might be a big hurdle to overcome, as we'd basically need to re-define most distributions to accommodate something like this.
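A small sketch of that hurdle, using a hypothetical `Lazy{f}` wrapper (defined here so the block stands alone) and two toy Beta-like structs. With one shared type parameter, construction fails for two differently transformed arguments; with one type parameter per field it succeeds, and the resulting type is still concrete.

```julia
# Hypothetical lazy wrapper: Lazy{f}(x) stands for f(x).
struct Lazy{F, T}
    x::T
end
Lazy{F}(x::T) where {F, T} = Lazy{F, T}(x)
value(l::Lazy{F}) where {F} = F(l.x)
value(x::Real) = x

# One shared type parameter, as in Beta(α::T, β::T).
struct BetaSameT{T}
    α::T
    β::T
end

# One type parameter per field.
struct BetaSepT{A, B}
    α::A
    β::B
end

shape_params(d) = (value(d.α), value(d.β))
```

`BetaSameT(Lazy{exp}(0.0), Lazy{abs}(-2.0))` throws a `MethodError` because `Lazy{exp, Float64}` and `Lazy{abs, Float64}` are different concrete types, while `BetaSepT` accepts the same arguments.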

torfjelde commented 4 years ago

And I think something like the following works okay as a "default" where we just allow the type itself to specify how to handle the unconstrained-to-constrained transformation:

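```julia
using Distributions

abstract type Constrained end

struct Unconstrained{T} <: Constrained
    val::T
end
value(c::Unconstrained) = c.val

# a constructor from another package must be qualified to be extended
Distributions.Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), exp(value(σ)))
```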

Could also do something like Unconstrained{T, F} where F is a callable. Then we can use

```julia
# alternative definition: F is a callable applied by `value`
struct Unconstrained{T, F} <: Constrained
    val::T
end
value(c::Unconstrained{T, F}) where {T, F} = F(c.val)

# when `F = identity` we have a default treatment
Distributions.Normal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} = Normal(value(μ), exp(value(σ)))

# in this case we have to assume that `value` takes care of the transformation
Distributions.Normal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} = Normal(value(μ), value(σ))
```

Need to think about this further, but doesn't seem like a horrible approach.
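A dependency-free sketch to check that the two-parameter variant dispatches as intended. `StandInNormal` and `mknormal` are hypothetical stand-ins for `Normal` and its constructor, and the convenience constructors for `Unconstrained` are additions for illustration.

```julia
abstract type Constrained end

# Unconstrained value plus the callable F mapping it to the constrained space.
struct Unconstrained{T, F} <: Constrained
    val::T
end
Unconstrained(val::T) where {T} = Unconstrained{T, identity}(val)
Unconstrained(val::T, f) where {T} = Unconstrained{T, f}(val)

value(c::Unconstrained{T, F}) where {T, F} = F(c.val)

# Stand-in for Distributions.Normal.
struct StandInNormal{T<:Real}
    μ::T
    σ::T
end

# F = identity for both arguments: apply the default exp to σ.
mknormal(μ::Unconstrained{T, identity}, σ::Unconstrained{T, identity}) where {T} =
    StandInNormal(value(μ), exp(value(σ)))

# Any other combination: assume `value` already performs the transformation.
mknormal(μ::Unconstrained{T}, σ::Unconstrained{T}) where {T} =
    StandInNormal(value(μ), value(σ))
```

The identity method is strictly more specific, so it wins whenever both arguments use `identity`. Mixing `identity` for μ with, say, `abs2` for σ falls through to the generic method, which trusts `value`.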