mitchelloharawild / distributional

Vectorised distributions for R
https://pkg.mitchelloharawild.com/distributional
GNU General Public License v3.0

family() to optionally return the types of distributions in the mixture #110

Open statasaurus opened 1 month ago

statasaurus commented 1 month ago

At the moment, there really isn't an easy way to look at the contents of mixture distributions. I am trying to determine whether all the components of a mixture are normal, so that I can just use conjugacy to calculate the posterior. But family() only returns that it is a mixture. It would be good to have an additional argument to family() (or something similar) that returns the distributions making up the mixture.


Brief description of the problem

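library(distributional)
library(vctrs)  # for vec_data()
library(purrr)  # for map_lgl()
x <- dist_mixture(dist_normal(0, 1), dist_normal(5, 2), weights = c(0.3, 0.7))
family(x)  # only reports that x is a mixture
# Workaround: strip the vctrs wrapper and inspect the mixture's components
foo <- vec_data(x)
foo[[1]]$dist |>
  map_lgl(\(x) inherits(x, "dist_normal"))
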
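
A minimal sketch of the same conjugacy check, vectorised over a whole distribution vector. all_normal_components() is a hypothetical helper, not part of distributional, and it assumes the internal layout the workaround above relies on: vec_data() exposes each mixture element as a list whose `dist` field holds the component distributions.

```r
library(distributional)

# Hypothetical helper (not part of distributional): TRUE for each element of `x`
# that is a plain normal, or a mixture whose components are all plain normals.
# Assumes vec_data() exposes a mixture element's components in its `dist` field.
all_normal_components <- function(x) {
  vapply(vctrs::vec_data(x), function(d) {
    if (inherits(d, "dist_mixture")) {
      # Check the class of every component of this mixture
      all(vapply(d$dist, inherits, logical(1), what = "dist_normal"))
    } else {
      inherits(d, "dist_normal")
    }
  }, logical(1), USE.NAMES = FALSE)
}

x <- c(
  dist_mixture(dist_normal(0, 1), dist_normal(5, 2), weights = c(0.3, 0.7)),
  dist_mixture(dist_normal(0, 1), dist_uniform(0, 1), weights = c(0.5, 0.5)),
  dist_normal(1, 1)
)
all_normal_components(x)
```

Only the direct components are checked here; a mixture nested inside a mixture is reported as not normal.
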
mitchelloharawild commented 1 month ago

This seems a bit tricky to design, and the same problem applies to the other distribution modifiers: dist_inflated(), dist_truncated(), and dist_transformed().
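
For illustration, a quick check of that claim: like a mixture, each modifier reports only its own family name, not the distribution it wraps (the family strings are the ones the recursive approach tests for).

```r
library(distributional)

# Each modifier hides the wrapped distribution behind its own family name
family(dist_truncated(dist_normal(0, 1), lower = 0))
family(dist_inflated(dist_poisson(3), prob = 0.2))
family(dist_transformed(dist_normal(0, 1), exp, log))
```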

Here's a recursive approach:

library(distributional)
x <- dist_mixture(dist_normal(0, 1), dist_normal(5, 2), weights = c(0.3, 0.7))

# Replace the family of each modified distribution with the families of the
# distributions it wraps, recursing into them
get_base_families <- function(x) {
  fam <- family(x)
  is_modified <- fam %in% c("mixture", "transformed", "inflated", "truncated")
  if (any(is_modified)) {
    # Assigning a list here turns `fam` into a (nested) list
    fam[is_modified] <- lapply(parameters(x[is_modified])$dist, function(dist) lapply(dist, get_base_families))
  }
  fam
}

get_base_families(x)
#> [[1]]
#> [[1]][[1]]
#> [1] "normal"
#> 
#> [[1]][[2]]
#> [1] "normal"

Created on 2024-05-15 with reprex v2.0.2
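
A variant of the recursive approach, sketched under assumptions: it works element by element on vec_data() and also descends into the single-distribution modifiers. element_base_families() and base_families() are hypothetical helpers; like the code above (via parameters()), they assume each modifier stores the distribution(s) it wraps in a field named `dist`. Each element's base families are flattened into one character vector.

```r
library(distributional)

# Hypothetical helpers (not part of distributional). They assume mixtures and
# the single-distribution modifiers store what they wrap in a `dist` field.
element_base_families <- function(d) {
  if (inherits(d, "dist_mixture")) {
    # A mixture holds a list of components: recurse into each and flatten
    unlist(lapply(d$dist, element_base_families), use.names = FALSE)
  } else if (inherits(d, c("dist_transformed", "dist_inflated", "dist_truncated"))) {
    # These modifiers wrap a single distribution: recurse into it
    element_base_families(d$dist)
  } else {
    family(d)
  }
}

# One character vector of base families per element of the distribution vector
base_families <- function(x) {
  lapply(vctrs::vec_data(x), element_base_families)
}

inner <- dist_mixture(dist_normal(1, 1), dist_normal(2, 1), weights = c(0.5, 0.5))
x <- c(
  dist_mixture(dist_normal(0, 1), dist_transformed(dist_normal(5, 2), exp, log),
               weights = c(0.3, 0.7)),
  dist_mixture(dist_normal(0, 1), inner, weights = c(0.5, 0.5))
)
fams <- base_families(x)
```

Flattening discards the structure: a transformed or truncated normal is reported as "normal", so a conjugacy check would still need to rule out those modifiers separately.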

statasaurus commented 1 month ago

Something like this in the package would be super helpful. Thank you!