JuliaStats / Distributions.jl

A Julia package for probability distributions and associated functions.

Adding a NamedTupleVariate #1762

Open sethaxen opened 1 year ago

sethaxen commented 1 year ago

It could be useful to add a NamedTupleVariate variate form, along with the necessary default methods, to this package. One concrete use case is a product distribution whose individual components are easy to access by name.

Here's a barebones implementation:

```julia
using Distributions, Random

# New variate form: each draw is a NamedTuple
abstract type NamedTupleVariate <: VariateForm end

# Product of independent component distributions, indexed by name
struct NamedTupleProductDistribution{Tnames,Tdists,eltypes,S<:ValueSupport} <:
       Distribution{NamedTupleVariate,S}
    dists::NamedTuple{Tnames,Tdists}
end
function NamedTupleProductDistribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Vararg{Distribution}}}
    eltypes = Tuple{map(eltype, values(dists))...}
    # would be better to allow mixed ValueSupports here
    # (_product_valuesupport is an internal Distributions.jl helper)
    vs = Distributions._product_valuesupport(dists)
    return NamedTupleProductDistribution{K,V,eltypes,vs}(dists)
end

# Let product_distribution accept a NamedTuple of distributions
function Distributions.product_distribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Vararg{Distribution}}}
    return NamedTupleProductDistribution(dists)
end

# NamedTuple type assembled from the component eltypes
function Distributions.eltype(::NamedTupleProductDistribution{K,<:Any,V}) where {K,V}
    return NamedTuple{K,V}
end

# Joint support: every named value lies in its component's support
function Distributions.insupport(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return all(Base.splat(insupport), zip(dist.dists, x))
end

function Distributions.pdf(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return exp(logpdf(dist, x))
end
# Independence: the joint log-density is the sum of the component log-densities
function Distributions.logpdf(
    dist::NamedTupleProductDistribution{K}, x::NamedTuple{K}
) where {K}
    return mapreduce(logpdf, +, dist.dists, x)
end

# Sample each component independently, keeping the names
function Distributions.rand(
    rng::AbstractRNG, dist::NamedTupleProductDistribution{K}
) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), dist.dists))
end
```

And an example usage:

```julia
julia> dist = product_distribution((; x=Normal(), y=Dirichlet(3, 1)));

julia> x = rand(dist)
(x = 0.35877545055901744, y = [0.5772411031132897, 0.2856859832919011, 0.13707291359480941])

julia> insupport(dist, x)
true

julia> pdf(dist, x)
0.7481503903444227
```
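Because the components are independent, the joint density is the product of the component densities. In the session above, Dirichlet(3, 1) is uniform on the probability simplex with constant density Γ(3) = 2, so pdf(dist, x) is simply twice the standard normal density at x.x (about 2 × 0.3741 ≈ 0.748). The minimal sketch below checks this factorization using only existing Distributions.jl calls on a plain NamedTuple of distributions, without the proposed type; the variable names and the seed are illustrative only.

```julia
using Distributions, Random

# Independent named components, like `dist.dists` in the sketch above
dists = (; x = Normal(), y = Dirichlet(3, 1))

rng = MersenneTwister(1762)  # seeded only to make the example reproducible

# `map` over a NamedTuple keeps the field names, so each component is
# sampled independently and stays addressable by name
sample = map(d -> rand(rng, d), dists)

# Independence: the joint log-density is the sum of the component log-densities
joint_logpdf = sum(map(logpdf, dists, sample))

# Each component (and its marginal) is still reachable by name
marginal_y = logpdf(dists.y, sample.y)   # log(2) for a point in the simplex
```

This name-preserving `map` over the components is the same mechanism the `rand` and `logpdf` methods above rely on.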
aplavin commented 1 year ago

See https://github.com/invenia/KeyedDistributions.jl for an array-based take on this that builds on keyed arrays. But yes, support for Base NamedTuples would also be useful.

Red-Portal commented 9 months ago

Hi, I also think this would be a really useful idea. Are there any plans to move this forward?

sethaxen commented 9 months ago

I had a local draft, which I just pushed as #1803. It still needs tests and some clarification about eltypes, but it works.
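The thread does not say what needs clarifying about eltypes. One wrinkle worth checking, offered here as an assumption rather than as the issue author's diagnosis: for a multivariate component, `eltype` gives the element type of a draw, not the type of the draw itself. A NamedTuple type assembled from the component eltypes, as the barebones `eltype` method does, therefore does not match the type of an actual joint draw once a component like `Dirichlet` is involved. A minimal sketch using only existing Distributions.jl calls (the names `nt_eltype` and `draw` are illustrative):

```julia
using Distributions

d = Dirichlet(3, 1)
eltype(d)          # Float64: the type of each entry of a draw
typeof(rand(d))    # Vector{Float64}: the type of a whole draw

dists = (; x = Normal(), y = Dirichlet(3, 1))

# NamedTuple type built from component eltypes, mirroring the sketch's constructor
nt_eltype = NamedTuple{keys(dists), Tuple{map(eltype, values(dists))...}}

# Type of an actual joint draw: one sample per component, names preserved
draw = map(rand, dists)

typeof(draw) == nt_eltype   # false: y is a Vector{Float64}, not a Float64
```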