JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Incompatibility between rand(MvNormal()) and AutoDiff #813

Closed theogf closed 5 years ago

theogf commented 5 years ago

Hello, it is unfortunately not possible to use automatic differentiation with (at least) the MvNormal distribution. The following code fails at rand(p) due to an incorrect conversion to Float64:

using Distributions, ForwardDiff
function foo(mu)
    p = MvNormal(mu)                # multivariate normal parameterised by the mean mu
    sum(rand(p) for _ in 1:100)     # fails here: the samples are converted to Float64
end
ForwardDiff.gradient(foo, rand(10))
matbesancon commented 5 years ago

@theogf PR welcome on that

theogf commented 5 years ago

So I found the source of the error, but it's in common.jl, so I'm not sure whether the change would be breaking: https://github.com/JuliaStats/Distributions.jl/blob/817bd83326f9d562ec88c0e782e60d25d64862ad/src/common.jl#L51 says that, whatever the type of the distribution is, eltype returns Float64 for a continuous distribution. Since eltype is called many times during MvNormal sampling, one always ends up with Float64 samples. I don't know how other distributions depend on eltype, but a quick fix would be overloading it for MvNormal types.
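
A minimal sketch of such an overload (hypothetical, not necessarily the change that was eventually merged) would make the element type follow the distribution's parameter type instead of the Float64 fallback in common.jl:

using Distributions

# Hypothetical overload: let eltype return MvNormal's parameter type T
# (e.g. a ForwardDiff.Dual) instead of the generic Float64 fallback.
Base.eltype(::Type{<:MvNormal{T}}) where {T} = T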

matbesancon commented 5 years ago

Yes this should be overloaded for (almost?) all distributions

matbesancon commented 5 years ago

@theogf did #882 close this? The error I get now is that the function does not return a scalar, which ForwardDiff.gradient requires. It seems fixed with:

julia> function foo(mu)
           p = MvNormal(mu)
           sum(sum(rand(p) for _ in 1:100))
       end
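
and then calling the gradient as in the original snippet (output omitted here):

julia> ForwardDiff.gradient(foo, rand(10))
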
andreasnoack commented 3 years ago

Is it really reasonable to expect rand to be differentiable? Making the sampler structs parametric introduces a lot of complications since the samplers are usually written for a specific precision. It also touches on the issue of the variate type vs the parameter type. For most distributions, the variate type doesn't follow the parameter type.

mschauer commented 3 years ago

Yes, I think this could be important: differentiability of a + b*randn(rng) conditional on the RNG state is practical (keyword: normalising flows).
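
As a minimal sketch of that idea (the function and seed below are made up for illustration): with the RNG state fixed, a + b*randn(rng) is a smooth function of (a, b), so ForwardDiff can differentiate a Monte Carlo estimate through the samples.

using Random, ForwardDiff

# Fixing the seed makes the draws a deterministic, differentiable function of (a, b).
function mc_mean(theta)
    a, b = theta
    rng = Random.MersenneTwister(42)                 # fixed seed, purely illustrative
    sum(a + b * randn(rng) for _ in 1:100) / 100
end

ForwardDiff.gradient(mc_mean, [0.0, 1.0])            # gradient of the estimate w.r.t. (a, b)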

devmotion commented 3 years ago

Related: https://github.com/TuringLang/DistributionsAD.jl/issues/123

I think in many cases it is not important to implement the samplers in a differentiable way but it would be useful to add custom adjoints, probably based on ChainRulesCore.
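
For a location-scale family, such a rule could look roughly like this (a hedged sketch based on the reparameterisation x = μ + σ*z, not an actual rule in Distributions.jl or DistributionsAD.jl):

using ChainRulesCore, Distributions, Random

# Hypothetical rrule for rand(rng, Normal(μ, σ)): draw z ~ N(0, 1) and return
# x = μ + σ*z, so the pullback only needs ∂x/∂μ = 1 and ∂x/∂σ = z.
function ChainRulesCore.rrule(::typeof(rand), rng::AbstractRNG, d::Normal)
    z = randn(rng)
    x = d.μ + d.σ * z
    function rand_pullback(Δx)
        return NoTangent(), NoTangent(), Tangent{typeof(d)}(μ = Δx, σ = Δx * z)
    end
    return x, rand_pullback
end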

andreasnoack commented 3 years ago

But does this generalize beyond location/scale? I don't think e.g. GammaGDSampler is differentiable in the shape parameter. Notice that https://github.com/JuliaStats/Distributions.jl/issues/1024 is a consequence of the complexity that a parametric struct introduces.

mschauer commented 3 years ago

Every y = model(rng, x), where you observe y and want to know x, invites taking derivatives of model with respect to x; that is not restricted to means and Gaussianity. So if it's possible (implying differentiable dependence on x for a fixed Random.seed!), you would like to allow it.

andreasnoack commented 3 years ago

It would be great if we could figure out a way to handle this that isn't generally broken and only sometimes (even if very often) works. Recall that the GammaGDSampler is currently broken for some inputs because of the type parameters, and it didn't take too much effort to find an example where the variates aren't continuous in the shape parameter:

julia> _seed
165

julia> tmp = [rand(Random.MersenneTwister(_seed), Distributions.GammaGDSampler(Gamma(_s, 0.1))) for _s in s];  # s: a range of shape parameters (defined earlier, not shown here)

[plot omitted: GammaGDSampler variates plotted against the shape parameter, showing a discontinuity]

I generally think you should be very reluctant to allow loose signatures for methods that exploit details of floating point numbers such as their precision.

I'm wondering if instead, we could define AD rules for rand methods. They could (hopefully?) be restricted to the parameters for which we know the variates are differentiable given a seed. For the Gamma it might actually be a problem that the distribution only has a single type parameter, since that excludes the possibility of restricting Duals to the scale parameter.

devmotion commented 3 years ago

I'm wondering if instead, we could define AD rules for rand methods.

This is what I had in mind when I said that it might be better to add custom adjoints/AD rules instead of trying to make the sampling algorithm itself differentiable.

For the Gamma it might actually be a problem that the distribution only has a single type parameter, since that excludes the possibility of restricting Duals to the scale parameter.

This would only be a problem for AD systems that operate with special number types such as Dual but not for e.g. Zygote.

mschauer commented 3 years ago

That single type parameter shouldn’t pose a problem, one can promote the other parameter to dual too.
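
For instance, something along these lines (purely illustrative):

using Distributions, ForwardDiff

# Put the derivative "seed" on the scale only and give the shape a Dual with
# zero partials, so both parameters fit Gamma's single type parameter.
d = Gamma(ForwardDiff.Dual(2.0, 0.0), ForwardDiff.Dual(0.1, 1.0))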

andreasnoack commented 3 years ago

That single type parameter shouldn’t pose a problem, one can promote the other parameter to dual too.

The point I tried to make was that you'd have to restrict the argument type for the shape parameter to Float64 and not allow Duals.

mschauer commented 3 years ago

It's not differentiable in the shape, but it can run on duals with partials equal to 0? I don't think it is the responsibility of a package to prevent dual inputs to non-differentiable functions.

andreasnoack commented 3 years ago

Let me elaborate: roughly speaking, there are two kinds of floating point methods

  1. "Core", such as (+)(::Float64, Float64), exp(::Float32), and exp(::Float64).
  2. "Compositions", such as logistic(::Real), norm(::Vector{<:Number}) and most user defined methods

The first group exploits details of the computer representation of the numbers, such as calling an LLVM intrinsic or a libm function with a specific precision. However, the group also includes native Julia implementations that evaluate series to a specific precision, so many of the definitions in SpecialFunctions fall into this group. I'm arguing that some/most samplers are part of this group as well. The GammaGDSampler is such an example, see e.g. https://github.com/JuliaStats/Distributions.jl/blob/8c7f400bd8b9ffe3773f2ee52501ee75b0fba3cd/src/samplers/gamma.jl#L43-L55 and https://github.com/JuliaStats/Distributions.jl/blob/8c7f400bd8b9ffe3773f2ee52501ee75b0fba3cd/src/samplers/gamma.jl#L65-L74. My argument is that whenever methods evaluate to a specific precision like that, we have to restrict the signature to ensure correctness.

The second group is composed out of "Core" methods, and I completely agree that such definitions should have as loose a signature as possible to allow for as many number types as possible. Regarding AD, we need rules for the "Core" group, and the beauty is that AD then automatically works for the second group, provided that we have used sufficiently loose signatures in the method definitions.

What we are currently doing is treating the sampler as a "Composition" method. I'm arguing that this isn't sound and that we'd have to make it a "Core" method and define AD rules for it. Specifically, we only need to consider the scale == 1 version as the "Core" method. The scaling can then be handled as a "Composition", which is why I said that it might be better to split the type parameter in two.
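
Concretely, such a split could look something like this (hypothetical helper names, just to illustrate the idea; assuming α ≥ 1 so the GD sampler applies):

using Distributions, Random

# "Core": the precision-specific GD sampler at unit scale, with a strict Float64
# signature (and, in this scheme, a hand-written AD rule where one is valid).
rand_unit_gamma(rng::AbstractRNG, α::Float64) =
    rand(rng, Distributions.GammaGDSampler(Gamma(α, 1.0)))

# "Composition": the scaling is plain arithmetic, so θ can be any Real
# (including a Dual) and generic AD machinery handles it.
rand_gamma(rng::AbstractRNG, α::Float64, θ::Real) = θ * rand_unit_gamma(rng, α)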

mschauer commented 3 years ago

I think we are on the same page and I agree with your line of reasoning; only, in isolation I wouldn't restrict

julia> f(v) = v * Base.Math.@horner(v,   # @horner is the internal Horner-scheme macro used in the samplers
           0.333333333,
           -0.249999949,
           0.199999867,
           -0.1666774828,
           0.142873973,
           -0.124385581,
           0.110368310,
           -0.112750886,
           0.10408986)

to Float64 because it does exactly the right thing for e.g.

julia> using IntervalArithmetic

julia> f(1.0..(1+eps()))
[0.236851, 0.236852]

at the right precision.