TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Change broadcasting for ~ macro #476

Closed: trappmartin closed this issue 4 years ago

trappmartin commented 6 years ago

Currently the ~ macro automatically uses broadcasting if the left-hand side is a vector and the right-hand side a univariate distribution. It would be more convenient for testing and for the user to change this to an explicit syntax, i.e. .~, which would also better match standard Julia syntax.
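For illustration, the difference would look something like this (a minimal sketch; the model definitions are hypothetical, not the final API):

using Turing

# Current behaviour: `~` broadcasts implicitly when the LHS is a
# vector and the RHS a univariate distribution.
@model function implicit_demo(x)
    x ~ Normal(0, 1)      # x is a vector; broadcasting is implied
end

# Proposed behaviour: broadcasting is spelled out with a dot,
# matching Julia's usual dot syntax.
@model function explicit_demo(x)
    x .~ Normal(0, 1)     # element-wise, explicit in the syntax
end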

@yebai What happens in the case of multivariate distributions?

yebai commented 6 years ago

I think ~ will check whether the left-hand side is a matrix or a vector of vectors. If so, it will treat the left-hand side as a vector of draws from a multivariate distribution.

willtebbutt commented 5 years ago

@trappmartin I agree with your point regarding Julia-ness, but we're quite closely tied to the Distributions.jl conventions, and they make very specific assumptions about what multiple samples from a distribution look like, depending upon whether it's a UnivariateDistribution, MultivariateDistribution, or a MatrixDistribution (see the Distributions.jl documentation for more info). Consequently, I think it actually makes sense not to follow the standard Julia conventions here, but rather to vectorise in the traditional sense and follow the Distributions.jl conventions.

To provide a concrete example of where we would get into trouble using a .~ notation: the Distributions convention for multiple samples from a MultivariateDistribution is to produce a matrix whose columns are observations. What, then, does the broadcasting .~ notation mean? On the one hand our intent would probably be to broadcast observation-wise, but broadcasting is already defined for AbstractMatrix in an element-wise fashion, so it would necessarily involve introducing some non-standard behaviour, which I believe would be a mistake.
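To make the convention concrete (a minimal sketch, assuming current Distributions.jl behaviour):

using Distributions, LinearAlgebra

d = MvNormal(zeros(3), Diagonal(ones(3)))  # 3-dimensional normal
X = rand(d, 5)                             # 3×5 Matrix: each column is one draw
size(X)                                    # (3, 5)

# Element-wise broadcasting over X would touch every scalar entry, not
# every column, so an element-wise `.~` has no natural observation-wise
# meaning here.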

trappmartin commented 5 years ago

Thanks, @willtebbutt! The convention of the Distributions package is quite unfortunate. I understand that with this convention, we cannot generically broadcast.

Let's keep this issue open and see if we can find a more transparent solution to this. I feel that the implicit broadcasting solution we currently have is sub-optimal.

willtebbutt commented 5 years ago

Agreed regarding the convention. FWIW, I had the same problem in Stheno.jl with regard to data sets. The inputs to a GP (or any other model for that matter) might be a vector of scalars, a matrix for which each row or column should be interpreted as an observation, etc.

My solution is to require that all data sets subtype AbstractVector and define indexing behaviour. This way the semantics are unambiguous, but you can still have the data stored in whatever manner you like under the hood, and just use multiple dispatch in the usual way to specialise behaviour for different storage types. Case in point, I have a type called ColsAreObs (read: Columns are observations):

struct ColsAreObs{T, TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
    X::TX
    ColsAreObs(X::TX) where {T, TX<:AbstractMatrix{T}} = new{T, TX}(X)
end

designed exactly to handle the situation described above, with the interpretation that the nth column of X is the nth observation. I then specialise a number of implementations for this type.
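The required indexing behaviour might look something like this (a minimal sketch, not the actual Stheno.jl code):

# The nth element of a ColsAreObs is the nth column of the wrapped
# matrix, so generic AbstractVector code sees a vector of observations.
Base.size(D::ColsAreObs) = (size(D.X, 2),)
Base.getindex(D::ColsAreObs, n::Int) = D.X[:, n]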

With this design you lose some conventions that people are used to seeing, but you gain consistency in terms of the interface for all of the data types, without sacrificing any performance (provided that you specialise correctly).

yebai commented 5 years ago

Related discussion from Soss.jl:

https://github.com/cscherrer/Soss.jl/issues/14

xukai92 commented 5 years ago

Another reason to stick to Distributions' convention is that the vectorised version of logpdf is faster than the broadcast version for some distributions, e.g. MvNormal:

julia> using Distributions, BenchmarkTools

julia> d = MvNormal(zeros(2), ones(2))
DiagNormal(
dim: 2
μ: [0.0, 0.0]
Σ: [1.0 0.0; 0.0 1.0]
)

julia> rvs = rand(d, 1000);

julia> @benchmark sum(logpdf(d, rvs))
BenchmarkTools.Trial: 
  memory estimate:  23.70 KiB
  allocs estimate:  3
  --------------
  minimum time:     4.171 μs (0.00% GC)
  median time:      5.322 μs (0.00% GC)
  mean time:        7.590 μs (22.46% GC)
  maximum time:     5.710 ms (99.87% GC)
  --------------
  samples:          10000
  evals/sample:     7

julia> rvs_col = [rvs[:,i] for i = 1:1000];

julia> @benchmark sum(logpdf.(Ref(d), rvs_col))
BenchmarkTools.Trial: 
  memory estimate:  102.63 KiB
  allocs estimate:  1022
  --------------
  minimum time:     53.432 μs (0.00% GC)
  median time:      57.578 μs (0.00% GC)
  mean time:        70.490 μs (9.19% GC)
  maximum time:     34.866 ms (99.74% GC)
  --------------
  samples:          10000
  evals/sample:     1

Ideally we'd like to call the vectorised version of logpdf in Turing instead of the broadcast version.

This is not entirely consistent within Distributions, though: the vectorised version is deprecated for univariate distributions. If you try this for Normal you will see "Warning: logpdf(d::UnivariateDistribution, X::AbstractArray) is deprecated, use logpdf.(d, X) instead."
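For example, in the univariate case (a small sketch reproducing the warning quoted above, assuming a Distributions version from that era):

using Distributions

x = rand(Normal(), 10)
logpdf(Normal(), x)    # deprecated: emits the warning quoted above
logpdf.(Normal(), x)   # the recommended broadcast form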

mkborregaard commented 5 years ago

Is the Distributions convention set in stone?

willtebbutt commented 5 years ago

Is the Distributions convention set in stone?

I'm not sure what the Distributions maintainers' position is on this. FWIW my personal feeling is that we should consider adopting the convention I discussed above in Turing for consistency's sake. If we do that we could (either implicitly or explicitly) require that

  1. The type of the lhs of a ~ is distributed over by the rhs, and
  2. The lhs of a .~ is compatible with broadcasting (whatever that looks like)

edit: we should find out what the Distributions maintainers' position is on changing their convention for multiple samples.

mohamed82008 commented 5 years ago

So here are some thoughts I have on this issue. I like the .~ syntax, and I think we can accommodate the normal broadcast semantics, the Distributions.jl convention, and the distribution-map semantics from #819, with a small price to pay. The key aspect here is to differentiate between the LHS-aware .~ and the LHS-unaware ~. The former tries to use "broadcast" in the most do-what-I-mean (DWIM) fashion. It turns out we can set a decent convention that supports all the possible options I can think of and is still somewhat in the DWIM spirit.

Let's take the ~ case first. I propose we have the following:

  1. ~ doesn't check the LHS, it only assigns to it or calls setindex! if the LHS is an indexed array. If setindex! fails here, it is on the user because they have mismatched LHS and RHS.
  2. ~ lowers to the 4-argument assume(spl, dist, vn, vi). dist is the RHS.
  3. If the RHS is a distribution, we sample a value from that distribution.
  4. If the RHS is an array of distributions, we sample a value from each element of the distribution array and return an array of values with exactly the same shape. The multivariate distribution case is no exception here: a vector of multivariate distributions gives us a vector of vectors, not a matrix. Since we are assuming we don't know about the LHS, there is no reason to treat multivariate distributions differently in this case IMO.

The above is possible to achieve only with dispatch on the dist argument of the 4-argument assume.
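A rough sketch of that dispatch (the helper name sample_rhs is hypothetical; real code would thread this through assume):

using Distributions

# Hypothetical helper dispatching on the RHS type of `~`.
# A lone distribution yields one draw; an array of distributions yields
# an array of draws of exactly the same shape, with no special casing
# for multivariate elements (a Vector of MultivariateDistributions
# gives a vector of vectors, not a matrix).
sample_rhs(dist::Distribution) = rand(dist)
sample_rhs(dists::AbstractArray{<:Distribution}) = map(rand, dists)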

Now for the .~, I propose the following:

  1. .~ checks the LHS and mutates it in place, trying to use broadcast semantics in the most DWIM fashion. This means the LHS and RHS can be of different shapes, and it should still work.
  2. .~ lowers to a 5-argument assume(spl, rhs, vn, lhs, vi). The lhs can be passed as a view and modified inside the function. The view will only be for the right-most index in the case of x[...][...].
  3. If the RHS is a multivariate distribution or a vector of multivariate distributions, and the LHS is a matrix of non-vectors, we use the Distributions.jl multi-column sampling, so each column of the LHS gets a vector from its corresponding distribution. In the case where the LHS and RHS don't match in size, we need to "broadcast" in the Julia sense to stick to the spirit of .~. Luckily, we only need to handle rhs::Vector{<:MultivariateDistribution}, so the size broadcast here is a simple repeat, which will obviously be done in a lazy fashion. Even though we will support rhs::Vector{<:MultivariateDistribution} because it is nice to have, I think the main use of this syntax will be when rhs::MultivariateDistribution, since that will actually give a performance boost when we use Distributions.jl's multi-column sampling (see the sketch after this list). In the rhs::Vector{<:MultivariateDistribution} case, we will just use a loop underneath.
  4. For any other case, we just use normal Julia broadcasting, assigning a different value to each element in the LHS array. The size broadcasting can also be taken care of by Julia. This covers lhs::Matrix{<:Vector} with rhs::MultivariateDistribution and the likes quite nicely. Note that if we had tried to use Julia broadcast semantics in the previous point, it would have failed, provided eltype(lhs) is not Any or some other supertype of both scalars and vectors (and such eltypes are not very typical). So for the most part we have a DWIM approach here, except for the lhs::Array{<:Union{Real, Vector}} case and its likes, which I think is a small price to pay. We can also check for that case explicitly and give a warning or error to the user.
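A rough sketch of the column-wise case from point 3 (the function name dot_tilde! is hypothetical):

using Distributions

# Hypothetical sketch: a real-valued matrix LHS with a single multivariate
# RHS fills each column with one draw, reusing Distributions' multi-column
# sampling (rand(d, n) returns a dim × n matrix).
function dot_tilde!(lhs::AbstractMatrix{<:Real}, rhs::MultivariateDistribution)
    lhs .= rand(rhs, size(lhs, 2))
    return lhs
end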

Implementing the above would again require only careful dispatching. That said, doing it for all the inference algorithms may be annoying since we have quite a few. So I still need to figure out a nicer way to implement and maintain this easily.

FWIW, once we have the .~ and 5-argument assume standardized, we can use it to do neat GPU and CPU parallelism tricks using nothing but dispatch. @trappmartin this is relevant to your discussion in #830. Julia has a nice parallelism-in-type approach that we can exploit here instead of defining our own loop constructs. The loop construct can still be useful though in other more custom parallelism use-cases, so I am not against that either.
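To give a flavour of the parallelism-in-type idea (entirely hypothetical names; the threaded path assumes a Julia version whose default RNG is safe to use from multiple threads):

using Distributions

# Hypothetical sketch: the strategy is picked from the LHS storage type
# alone via dispatch. A plain Array takes the serial path; a dedicated
# threaded or GPU array type would dispatch to its own method.
fill_draws!(lhs::AbstractVector, d::UnivariateDistribution) = (lhs .= rand.(Ref(d)))

function fill_draws_threaded!(lhs::AbstractVector, d::UnivariateDistribution)
    Threads.@threads for i in eachindex(lhs)
        lhs[i] = rand(d)  # independent draw per element
    end
    return lhs
end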

So far we also don't have a multi-threading (shared memory parallelism) array type like DArray for distributed parallelism and GPUArray for GPU parallelism AFAIK. Even SharedArray is actually just for fancy distributed memory parallelism when the different cores are on a single machine IIUC, so it is not true shared memory parallelism with OpenMP-style threads. I remember seeing a comment by @ChrisRackauckas a while back that he wanted to implement this multi-threading array type, not sure if he made it or not. KissThreading.jl might be a good place to have such an array type, so I can start implementing one if it doesn't already exist somewhere. Anyways, we can worry about this after we have the .~ syntax nicely supported.

Comments?

ChrisRackauckas commented 5 years ago

So far we also don't have a multi-threading (shared memory parallelism) array type like DArray for distributed parallelism and GPUArray for GPU parallelism AFAIK. Even SharedArray is actually just for fancy distributed memory parallelism when the different cores are on a single machine IIUC, so it is not true shared memory parallelism with OpenMP-style threads. I remember seeing a comment by @ChrisRackauckas a while back that he wanted to implement this multi-threading array type, not sure if he made it or not. KissThreading.jl might be a good place to have such an array type, so I can start implementing one if it doesn't already exist somewhere. Anyways, we can worry about this after we have the .~ syntax nicely supported.

I had some prototypes back in v0.6, then they got massacred by the v1.0 broadcast changes, so I'd have to start from scratch and never picked it up again. So yes, please implement them in KissThreading.jl and I'll be forever grateful.

mohamed82008 commented 5 years ago

I had some prototypes back in v0.6, then they got massacred by the v1.0 broadcast changes, so I'd have to start from scratch and never picked it up again. So yes, please implement them in KissThreading.jl and I'll be forever grateful.

Cool, I will try to get a basic implementation done, at least to forward the map family to their tmap counterparts for now. Even broadcasting can probably be made simple with some help from LazyArrays.BroadcastArray, but I have never played with the broadcast machinery of Julia v1 before, so there will be a learning curve.

I also forgot to mention above that a similar approach will need to be taken for observe statements, but without the assignment part. This is the one relevant to @trappmartin's use case.

mohamed82008 commented 4 years ago

Closed via #965.