TuringLang / DistributionsAD.jl

Automatic differentiation of Distributions using Tracker, Zygote, ForwardDiff and ReverseDiff
MIT License

MappedDistribution #12

Open mohamed82008 opened 5 years ago

mohamed82008 commented 5 years ago

Sometimes it is useful to define a multivariate distribution over independent variables by generating distributions on the fly, using different distribution parameters for each variable according to some rule/function. Defining an efficient logpdf and adjoint for such a distribution can give significant computational savings. This is similar to the Soss For combinator.

cscherrer commented 5 years ago

Here's my current setup:

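struct For{F,T,D,X}
    f::F   # maps each element of θ to a distribution
    θ::T   # what to map over (e.g. array dimensions or an iterator of parameters)
end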

where...

Some example use cases:

# T = NTuple{2,Int}
x ~ For(10,3) do i,j
    Bernoulli(j/(i+j))   # j/i can exceed 1 (e.g. i=1, j=2), which Bernoulli rejects
end
# T = Base.Generator{Base.OneTo{Int64},Base.var"#174#175"{Array{Float64,2}}}
y ~ For(eachrow(X)) do xrow
    Normal(xrow' * β, 1)
end

We'll have different methods for rand, logpdf, etc., dispatching mostly on T.
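
To make "dispatching on T" concrete, here is a minimal sketch. ForSketch is a hypothetical, simplified stand-in (not Soss's actual For): it drops the D and X parameters, and its rand/logpdf methods dispatch on whether θ is a tuple of dimensions (the For(10,3) case) or an iterable of parameters (the eachrow(X) case). It assumes Distributions.jl, and the Bernoulli rule uses probability j/(i+j), which stays inside [0, 1].

```julia
using Distributions, Random

# Hypothetical, simplified stand-in for a For-style combinator (not Soss's API).
# `f` maps an index tuple or a parameter to a distribution; `θ` says what to map over.
struct ForSketch{F,T}
    f::F
    θ::T
end

# ForSketch(f, 10, 3) maps over the grid of indices (i, j), 1 ≤ i ≤ 10, 1 ≤ j ≤ 3.
ForSketch(f, dims::Int...) = ForSketch(f, dims)

# T is a tuple of dimensions: one draw per grid cell.
function Random.rand(d::ForSketch{F,T}) where {F,T<:Tuple{Vararg{Int}}}
    return [rand(d.f(Tuple(I)...)) for I in CartesianIndices(d.θ)]
end

# Any other T: treat θ as an iterable of parameters, one distribution per element.
Random.rand(d::ForSketch) = [rand(d.f(t)) for t in d.θ]

function Distributions.logpdf(d::ForSketch{F,T}, x::AbstractArray) where {F,T<:Tuple{Vararg{Int}}}
    @assert size(x) == d.θ
    return sum(logpdf(d.f(Tuple(I)...), x[I]) for I in CartesianIndices(x))
end

function Distributions.logpdf(d::ForSketch, x::AbstractArray)
    return sum(logpdf(d.f(t), xi) for (t, xi) in zip(d.θ, x))
end

# Usage mirroring the two examples above.
x_dist = ForSketch(10, 3) do i, j
    Bernoulli(j / (i + j))
end
x = rand(x_dist)

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
β = [0.5, -0.5]
y_dist = ForSketch(eachrow(X)) do xrow
    Normal(xrow' * β, 1)
end
y = rand(y_dist)
```

The more specific tuple method wins whenever θ is a tuple of Ints; everything else falls back to the iterator methods.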

Also, I currently have the following restrictions:

  1. D is consistent across indices
  2. support(d::D) is consistent across indices

Currently this targets "array-like" results, but in principle T can be anything, for example an iterator or a Real (for function spaces, GPs, etc.).

mohamed82008 commented 5 years ago

I don't think we need a restriction on D being the same. The logpdf can be something like this:

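import Distributions: logpdf
import Random: rand

# Sum the per-element log densities; f may return a different distribution for each θ[i].
function logpdf(dist::For, x::AbstractArray)
    @assert size(dist.θ) == size(x)
    return sum(1:length(dist.θ)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end
# One draw from each per-element distribution.
rand(dist::For) = rand.(dist.f.(dist.θ))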

Whether or not f returns the same distribution type for every element, the result should be inferrable by the Julia compiler.
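
As a quick check of that claim, the sketch below uses a hypothetical rule that switches between Normal and Laplace depending on the sign of the parameter. The per-element sum follows the same pattern as the logpdf method above, written without the For struct. Base.return_types shows what inference sees for the rule; for this rule it should be a small Union of the two distribution types, which Julia handles efficiently.

```julia
using Distributions

# Hypothetical rule: the distribution *family* depends on the parameter.
rule(t) = t > 0 ? Normal(t, 1.0) : Laplace(t, 1.0)

θ = [-1.0, 0.5, 2.0]
x = [0.0, 0.0, 0.0]

# Inspect what the compiler infers for rule(::Float64).
inferred = Base.return_types(rule, (Float64,))

# Same per-element sum as the logpdf method above.
lp = sum(1:length(θ)) do i
    logpdf(rule(θ[i]), x[i])
end
```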

mohamed82008 commented 5 years ago

Base.eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

mohamed82008 commented 5 years ago

Note that the above is a dynamically sized distribution. We can also get free specialization and inlining for small, fixed-size distributions when using θ::StaticArray.
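
A rough illustration of the fixed-size case, using a plain NTuple as a stand-in for a StaticArray so the example needs no extra packages (tuple_logpdf is a hypothetical helper). Because the length N is part of the type, the compiler can specialize map over the tuples for each N; StaticArrays' SVector should behave similarly.

```julia
using Distributions

# θ and x have a length N known at compile time, so `map` can be specialized per N.
function tuple_logpdf(f, θ::NTuple{N,Real}, x::NTuple{N,Real}) where {N}
    return sum(map((t, xi) -> logpdf(f(t), xi), θ, x))
end

# Each term is logpdf(Normal(μ, 1), μ) = -log(2π)/2, so the total is -1.5 * log(2π).
lp = tuple_logpdf(μ -> Normal(μ, 1.0), (0.0, 1.0, 2.0), (0.0, 1.0, 2.0))
```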

mohamed82008 commented 5 years ago

I think for Tracker

sum(logpdf.(dist.f.(dist.θ), x))

will be faster than

sum(1:length(dist.θ)) do i
    logpdf(dist.f(dist.θ[i]), x[i])
end

With the broadcast version, if either θ or x is a TrackedArray, all intermediates will also be TrackedArrays rather than individual TrackedReals.
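
Both forms compute the same number; the suggested speed difference is only in how Tracker records the computation (array-level operations for the broadcast, presumably one scalar operation per element for the loop). The sketch below checks the numerical equivalence on plain arrays, without Tracker; scaled_normal, θ and x are made-up example values.

```julia
using Distributions

# Made-up example values.
scaled_normal(μ) = Normal(μ, 2.0)
θ = [0.0, 1.0, 2.0, 3.0]
x = [0.5, 0.5, 2.5, 2.0]

# Broadcast form: an array of distributions, then an array of log densities, then one sum.
lp_broadcast = sum(logpdf.(scaled_normal.(θ), x))

# Loop form: one scalar logpdf per index, accumulated by `sum`.
lp_loop = sum(1:length(θ)) do i
    logpdf(scaled_normal(θ[i]), x[i])
end
```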

cscherrer commented 5 years ago

> I don't think we need a restriction on D being the same.

The most obvious reason for this is type stability, though there may be ways around that. In addition, the vast majority of models will satisfy this anyway, and it often opens up opportunities for optimization. For example, in cases where d.f maps to continuous distributions, how can we determine the bijection to ℝⁿ? Parameterizing by D makes this trivial.
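
A minimal sketch of why a D parameter helps, assuming a hypothetical to_real function (not Bijectors.jl's API) that maps a distribution's support to ℝ. Because D is part of the type, the transform is selected by dispatch once for the whole collection, without calling f on any element. ForByType is likewise a made-up stand-in for For.

```julia
using Distributions

# Hypothetical transforms from a distribution's support to ℝ, chosen by type alone.
to_real(::Type{<:Normal}, x) = x                 # support is already ℝ
to_real(::Type{<:Exponential}, x) = log(x)       # (0, ∞) → ℝ
to_real(::Type{<:Beta}, x) = log(x / (1 - x))    # (0, 1) → ℝ (logit)

# Hypothetical For-like wrapper with the distribution type D as a parameter.
struct ForByType{D,F,T}
    f::F
    θ::T
end
ForByType{D}(f::F, θ::T) where {D,F,T} = ForByType{D,F,T}(f, θ)

# The transform depends only on D, so d.f is never evaluated here.
to_real(d::ForByType{D}, x::AbstractArray) where {D} = to_real.(D, x)

d = ForByType{Exponential}(λ -> Exponential(λ), [1.0, 2.0])
z = to_real(d, [1.0, exp(1.0)])
```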

One thing I've found a bit tricky is making useful type information available without much computational cost. Unfortunately, in Julia we can't just ask a function about its codomain, so instantiating a For requires some computation in order to determine the types. So far, I've been trying to keep construction cheap by assuming distributions and supports are consistent, and computing them for a single index at construction time. Your eltype suggestion,

eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

is appealing, but it would make instantiation O(n).
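
The trade-off can be sketched as follows. ForFirst, disttype and eltype_all are hypothetical names. The constructor evaluates f only at first(θ) and records the result's type as D, which is O(1) but relies on restriction 1 above; eltype_all follows the mapreduce suggestion and touches every element, which is O(n).

```julia
using Distributions

# Hypothetical For-like wrapper that stores the distribution type D as a parameter.
struct ForFirst{F,T,D}
    f::F
    θ::T
end

# O(1) construction: evaluate f once, at the first element, and assume every
# element maps to the same distribution type.
function ForFirst(f::F, θ::T) where {F,T}
    D = typeof(f(first(θ)))
    return ForFirst{F,T,D}(f, θ)
end

disttype(::ForFirst{F,T,D}) where {F,T,D} = D

# O(n) alternative in the spirit of the eltype suggestion: visit every element.
eltype_all(d::ForFirst) = mapreduce(t -> eltype(d.f(t)), promote_type, d.θ)

d = ForFirst(μ -> Normal(μ, 1.0), [0.0, 1.0, 2.0])
```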

Above I suggested that For might also be useful for building distributions over function spaces. I may disagree with myself on this point: it drops the conditional independence assumption of the other For instances, and it would require some way to specify covariance.

Finally, we had some recent discussion on Discourse about the best approach for parallelism, which will be important for many cases.

cscherrer commented 5 years ago

Here's the current state, cleaned up a bit in Soss: https://github.com/cscherrer/Soss.jl/blob/dev/src/for.jl

I should be able to submit a PR today.

There's also iid, which is like For but without the distributional dependence on indices. I have a curried form, which I usually use like this:

x ~ Normal() |> iid(N)
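
For readers unfamiliar with the curried form: iid(N) returns a function, so Normal() |> iid(N) is just iid(N)(Normal()). Below is a hypothetical, minimal version of N independent copies of one distribution; iid_sketch and IIDSketch are illustrative names, not Soss's implementation.

```julia
using Distributions, Random

# Hypothetical stand-in for iid: independent copies of a single distribution.
struct IIDSketch{D,N}
    dist::D
    dims::NTuple{N,Int}
end

# Curried constructor: iid_sketch(n) returns a function of the distribution.
iid_sketch(dims::Int...) = d -> IIDSketch(d, dims)

Random.rand(d::IIDSketch) = rand(d.dist, d.dims...)

function Distributions.logpdf(d::IIDSketch, x::AbstractArray)
    @assert size(x) == d.dims
    # No index-dependent parameters, so every element shares one distribution.
    return sum(xi -> logpdf(d.dist, xi), x)
end

N = 5
x_dist = Normal() |> iid_sketch(N)
x = rand(x_dist)
```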

mohamed82008 commented 4 years ago

Thanks for the PR @cscherrer, and sorry for the late reply; I've been busy for the last few weeks. I will review your PR asap.