mohamed82008 opened 5 years ago
Here's my current setup:
struct For{F,T,D,X}
    f::F
    θ::T
end
where...

- F and T are specified in the struct
- D is the distribution returned
- X is the eltype of that distribution (unfortunately, often not available directly from D)

Some example use cases:
# T = NTuple{2,Int}
x ~ For(10,3) do i,j
    Bernoulli(j/i)
end

# T = Base.Generator{Base.OneTo{Int64},Base.var"#174#175"{Array{Float64,2}}}
y ~ For(eachrow(X)) do xrow
    Normal(xrow' * β, 1)
end
We'll have different methods for rand, logpdf, etc., based mostly on T.
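A minimal sketch of what "different methods based mostly on T" could look like, assuming Distributions.jl is installed. This is not the Soss implementation: the two-parameter For here drops D and X for brevity, and the rule "a tuple of Ints means a grid of indices, anything else is iterated element by element" is an assumption made for illustration. The Bernoulli parameter is changed to 1/(i+j) so that it stays in [0, 1].

```julia
using Distributions

# Minimal stand-in for For (a sketch, not Soss's implementation).
struct For{F,T}
    f::F  # maps an index (or an element of θ) to a distribution
    θ::T  # what we iterate over
end

# T is a tuple of Ints, e.g. (10, 3): treat it as a grid of indices and call f(i, j, ...).
function Base.rand(d::For{<:Any,<:Tuple{Vararg{Int}}})
    return map(I -> rand(d.f(Tuple(I)...)), CartesianIndices(d.θ))
end

# Any other iterable (array, eachrow, generator, ...): call f on each element.
Base.rand(d::For) = map(t -> rand(d.f(t)), d.θ)

# Usage mirroring the two examples above.
x = rand(For((i, j) -> Bernoulli(1 / (i + j)), (10, 3)))   # 10×3 Matrix{Bool}

X = [1.0 2.0; 3.0 4.0; 5.0 6.0; 7.0 8.0]
β = [1.0, -1.0]
y = rand(For(xrow -> Normal(xrow' * β, 1), eachrow(X)))    # Vector{Float64}, length 4
```

Method dispatch picks the tuple method for For((i, j) -> ..., (10, 3)) because it is more specific than the generic For method.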
Also, I currently have the following restrictions:

- D is consistent across indices
- support(d::D) is consistent across indices

Currently this targets "array-like" results, but in principle T can be anything, for example an iterator or a Real (for function spaces, GPs, etc.).
I don't think we need a restriction on D being the same. The logpdf can be something like this:
function Distributions.logpdf(dist::For, x::AbstractArray)
    @assert size(dist.θ) == size(x)
    # the components are independent, so the joint log-density is a sum
    return sum(1:length(dist.θ)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end

Base.rand(dist::For) = rand.(dist.f.(dist.θ))
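A self-contained version of the logpdf and rand definitions above, assuming Distributions.jl is installed and that θ is an array indexable like x. The methods extend Distributions.logpdf and Base.rand; defining a fresh top-level logpdf or rand instead would shadow the library functions, and the inner calls on each component distribution would then fail with a MethodError.

```julia
using Distributions

# For with an indexable θ (array-like); a sketch, not the Soss type.
struct For{F,T}
    f::F
    θ::T
end

# Extend Distributions.logpdf rather than defining a new, unrelated logpdf.
function Distributions.logpdf(dist::For, x::AbstractArray)
    @assert size(dist.θ) == size(x)
    # Components are independent, so the joint log-density is a sum.
    return sum(1:length(dist.θ)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end

# Extend Base.rand so that rand on each component still hits the usual method.
Base.rand(dist::For) = rand.(dist.f.(dist.θ))

μ = [0.0, 1.0, 2.0]
d = For(m -> Normal(m, 1.0), μ)
x = rand(d)          # Vector{Float64}, one draw per component
lp = logpdf(d, x)    # scalar joint log-density
```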
Whether f returns the same distribution or not, this should be inferrable by the Julia compiler.
Base.eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))
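To see what the mapreduce/promote_type eltype does when f does not always return the same distribution type, here is a small sketch, assuming Distributions.jl is installed. The Poisson/Normal mixture is a made-up example: discrete distributions report an integer eltype and Normal reports Float64, so the promoted result is Float64.

```julia
using Distributions

struct For{F,T}
    f::F
    θ::T
end

# Promote the eltypes of all component distributions. This constructs every
# component once, so it costs O(n) in the number of components.
Base.eltype(dist::For) =
    mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))

# Components that are not all the same distribution type:
# odd t gives a Poisson (integer-valued), even t gives a Normal (Float64-valued).
mixed = For(t -> isodd(t) ? Poisson(t) : Normal(t, 1.0), 1:4)
eltype(mixed)   # Float64, since promote_type(Int, Float64) == Float64
```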
Note that the above is a dynamically sized distribution. We can also get free specialization and inlining for small, fixed-size distributions when using θ::StaticArray.
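A sketch of the fixed-size case, assuming both Distributions.jl and StaticArrays.jl are installed. With θ::SVector the length is part of the type, so the broadcasts in rand and logpdf return SVectors and Julia compiles methods specialized for that size; whether the loops are actually unrolled is up to the compiler.

```julia
using Distributions, StaticArrays

struct For{F,T}
    f::F
    θ::T
end

# Broadcasting versions of logpdf and rand; nothing here is StaticArrays-specific.
Distributions.logpdf(dist::For, x) = sum(logpdf.(dist.f.(dist.θ), x))
Base.rand(dist::For) = rand.(dist.f.(dist.θ))

# θ::SVector carries its length (3) in its type, so broadcasting over it yields
# SVectors and each method is compiled for this fixed size.
θ = SVector(0.0, 1.0, 2.0)
d = For(m -> Normal(m, 1.0), θ)
x = rand(d)       # an SVector{3, Float64}
lp = logpdf(d, x)
```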
I think that with Tracker,
sum(logpdf.(dist.f.(dist.θ), x))
will be faster than
sum(1:length(dist.θ)) do i
    logpdf(dist.f(dist.θ[i]), x[i])
end
With broadcasting, if either θ or x is a TrackedArray, all intermediates will also be TrackedArrays rather than TrackedReals.
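The two logpdf forms compared above compute the same quantity. The sketch below, assuming Distributions.jl is installed, only checks that they agree on plain Float64 inputs. It does not load Tracker, so it says nothing about which form is faster under AD; that comparison remains the conjecture stated above. The names logpdf_loop and logpdf_bcast are hypothetical.

```julia
using Distributions

struct For{F,T}
    f::F
    θ::T
end

# Loop form: one scalar logpdf call per index.
function logpdf_loop(dist::For, x)
    return sum(1:length(dist.θ)) do i
        logpdf(dist.f(dist.θ[i]), x[i])
    end
end

# Broadcast form: the component distributions and their log-densities are
# built as whole arrays, and the reduction is a single sum over an array.
logpdf_bcast(dist::For, x) = sum(logpdf.(dist.f.(dist.θ), x))

d = For(m -> Normal(m, 2.0), [0.0, 1.0, 2.0, 3.0])
x = [0.5, 0.5, 0.5, 0.5]
logpdf_loop(d, x) ≈ logpdf_bcast(d, x)   # true
```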
> I don't think we need a restriction on D being the same.
The most obvious reason for this is type stability, though there may be ways around that. In addition, the vast majority of models will satisfy this anyway, and it often opens up opportunities for optimization. For example, in cases where d.f maps to continuous distributions, how can we determine the bijection to ℝⁿ? Parameterizing by D makes this trivial.
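A sketch of why parameterizing by D helps with the bijection to ℝⁿ, assuming Distributions.jl is installed. D is fixed at construction from the first element of θ (so all components are assumed to share it), and a hypothetical helper to_unconstrained then chooses the transform by dispatch on D alone, with no per-component inspection. The log and logit transforms are standard choices for supports (0, ∞) and (0, 1); they are not taken from Soss.

```julia
using Distributions

# D (the component distribution type) is a type parameter; all components
# are assumed to share it.
struct For{F,T,D}
    f::F
    θ::T
end

# Fix D by evaluating f at the first element of θ only.
function For(f::F, θ::T) where {F,T}
    D = typeof(f(first(θ)))
    return For{F,T,D}(f, θ)
end

# Hypothetical helper: pick the transform to ℝⁿ from D alone, by dispatch.
to_unconstrained(::For{<:Any,<:Any,<:Normal}, x) = x                  # support ℝ
to_unconstrained(::For{<:Any,<:Any,<:Gamma}, x) = log.(x)             # support (0, ∞)
to_unconstrained(::For{<:Any,<:Any,<:Beta}, x) = log.(x ./ (1 .- x))  # support (0, 1), logit

g = For(k -> Gamma(k, 1.0), [1.0, 2.0, 3.0])
to_unconstrained(g, [1.0, exp(1.0), exp(-1.0)])   # ≈ [0.0, 1.0, -1.0]
```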
One thing I've found a bit tricky is making useful type information available without much computational cost. Unfortunately, in Julia we can't just ask a function about its codomain, so instantiating a For requires some computation to determine the types. So far, I've been trying to keep construction cheap by assuming distributions and supports are consistent, and computing them for a single index at construction time. Your eltype suggestion,
eltype(dist::For) = mapreduce(i -> eltype(dist.f(dist.θ[i])), promote_type, 1:length(dist.θ))
is appealing, but it would make instantiation cost O(n).
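A sketch of the single-index approach described above, assuming Distributions.jl is installed: f is evaluated at one element of θ, D is taken from that distribution, and eltype then becomes a constant read from the type instead of an O(n) pass. Taking X from one rand draw is an assumption made here because, as noted earlier in the thread, X is often not available directly from D; it costs one sample at construction.

```julia
using Distributions

struct For{F,T,D,X}
    f::F
    θ::T
end

# O(1) construction: evaluate f at a single element and assume every element
# gives the same distribution type D. X comes from one sample (an assumption).
function For(f::F, θ::T) where {F,T}
    d1 = f(first(θ))
    return For{F,T,typeof(d1),typeof(rand(d1))}(f, θ)
end

# eltype is read from the type parameters, with no pass over θ.
Base.eltype(::For{F,T,D,X}) where {F,T,D,X} = X

d = For(m -> Normal(m, 1.0), [0.0, 1.0, 2.0])
eltype(d)   # Float64
```

The trade-off is the one described above: construction stays cheap only because components other than the first are never checked, so a D that varies across indices would go unnoticed.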
Above I suggested For might also be useful for building distributions over function spaces. I may disagree with myself on this point, because it drops the conditional independence assumption of other For instances, and would require adding some way to specify covariance.
Finally, we had some recent discussion on Discourse about the best approach for parallelism, which will be important for many cases.
I'm cleaning this up a bit in Soss; here's the current state: https://github.com/cscherrer/Soss.jl/blob/dev/src/for.jl
I should be able to get a PR submitted today.
There's also iid, which is like For but without the distributional dependence on indices. I have a curried form, which I usually use like this:

x ~ Normal() |> iid(N)
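A minimal sketch of a curried iid, assuming Distributions.jl is installed. The IID type and this iid function are hypothetical stand-ins, not the Soss definitions; the point is only that iid(N) returns a function of a distribution, which is what makes the `Normal() |> iid(N)` pipe form work.

```julia
using Distributions

# Hypothetical iid combinator: n independent copies of a single distribution.
struct IID{D}
    dist::D
    n::Int
end

# Curried form: iid(n) returns a function, so `Normal() |> iid(N)` builds an IID.
iid(n::Int) = dist -> IID(dist, n)

Base.rand(d::IID) = rand(d.dist, d.n)

# Same distribution for every component, so the joint log-density is a plain sum.
function Distributions.logpdf(d::IID, x::AbstractVector)
    @assert length(x) == d.n
    return sum(xi -> logpdf(d.dist, xi), x)
end

N = 5
d = Normal() |> iid(N)
x = rand(d)   # Vector{Float64} of length 5
logpdf(d, x)
```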
Thanks for the PR, @cscherrer, and sorry for the late review; I've been busy the last few weeks. I will review your PR ASAP.
Sometimes it is useful to define a multivariate distribution over independent (but not necessarily identically distributed) variables by generating distributions on the fly, using different distribution parameters for each variable according to some rule/function. Defining an efficient logpdf and adjoint for such a distribution can give significant computational savings. This is similar to the For combinator in Soss.
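As one illustration of where an efficient logpdf can pay off, here is a sketch assuming Distributions.jl is installed: when every component is Normal(μᵢ, σ) with a shared σ, the sum of log-densities has a closed form that works directly on the parameter array without constructing n distribution objects. NormalMeans is a hypothetical marker type used only to dispatch to that specialization. The adjoint part of the proposal is not covered here.

```julia
using Distributions

struct For{F,T}
    f::F
    θ::T
end

# Generic fallback: one logpdf call per component.
function Distributions.logpdf(dist::For, x::AbstractVector)
    return sum(i -> logpdf(dist.f(dist.θ[i]), x[i]), eachindex(dist.θ, x))
end

# Hypothetical callable marker: μ ↦ Normal(μ, σ) with a shared σ.
struct NormalMeans{S}
    σ::S
end
(g::NormalMeans)(μ) = Normal(μ, g.σ)

# Specialization: Σᵢ log N(xᵢ | μᵢ, σ) = -n(log σ + log(2π)/2) - Σᵢ((xᵢ - μᵢ)/σ)²/2
function Distributions.logpdf(dist::For{<:NormalMeans}, x::AbstractVector)
    σ = dist.f.σ
    n = length(x)
    return -n * (log(σ) + log(2π) / 2) - sum(abs2, (x .- dist.θ) ./ σ) / 2
end

θ = [0.0, 1.0, 2.0]
x = [0.5, 0.5, 0.5]
fast = For(NormalMeans(2.0), θ)          # uses the closed form
generic = For(μ -> Normal(μ, 2.0), θ)    # uses the per-component fallback
logpdf(fast, x) ≈ logpdf(generic, x)     # true
```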