mjskay / ggdist

Visualizations of distributions and uncertainty
https://mjskay.github.io/ggdist/
GNU General Public License v3.0

Support distributional package #14

Closed: mjskay closed this issue 3 years ago

mjskay commented 4 years ago

Allow dist vectors (vectors of distribution objects from {distributional}) in the dist aesthetic. See https://github.com/mitchelloharawild/distributional/issues/24

mjskay commented 4 years ago

@mitchelloharawild this should now work with all of the stat_dist_... stats on the dev branch of {ggdist}. Some examples:

library(ggplot2)
library(ggdist)
library(distributional)
library(tibble)  # tribble()
library(dplyr)   # %>%, group_by()

dist_df = tribble(
    ~group, ~subgroup, ~dist,
    "a",       "h",     dist_normal(5, 1),
    "b",       "h",     dist_normal(7, 1.5),
    "c",       "h",     dist_normal(8, 1),
    "c",       "i",     dist_normal(9, 1),
    "c",       "j",     dist_normal(7, 1)
)
dist_df %>%
  ggplot(aes(x = group, dist = dist, fill = subgroup)) +
  stat_dist_eye(position = "dodge")


I'm particularly fond of how the syntax works out with dist_xxx() embedded in aes():

data.frame(alpha = seq(5, 100, length.out = 10)) %>%
  ggplot(aes(y = "", dist = dist_beta(alpha, 10), color = alpha)) +
  stat_dist_slab(fill = NA) +
  coord_cartesian(expand = FALSE) +
  scale_color_viridis_c() +
  ggtitle("Beta(alpha,10) distribution")


plus the obligatory lineribbon:

library(modelr)  # data_grid(), seq_range()

m_mpg = lm(mpg ~ hp * cyl, data = mtcars)

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  broom::augment(m_mpg, newdata = .) %>%
  ggplot(aes(x = hp, fill = ordered(cyl), color = ordered(cyl))) +
  stat_dist_lineribbon(
    aes(dist = dist_normal(.fitted, .se.fit)), 
    alpha = 1/4
  ) +
  geom_point(aes(y = mpg), data = mtcars) +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Dark2")


This is really nice, I think! It cleans up the syntax very well.

mitchelloharawild commented 4 years ago

😍 Very excited to see this in action!

mitchelloharawild commented 4 years ago

Is it possible to add a legend for the ribbon shade? Or is the shade currently produced by overlapping transparent ribbons? This is how distributional handles the ribbon part of a lineribbon (a line can be added with geom_line()):

library(distributional)
library(modelr)
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(ggplot2)
m_mpg = lm(mpg ~ hp * cyl, data = mtcars)

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101), level = c(50, 80, 95)) %>%
  broom::augment(m_mpg, newdata = .) %>%
  mutate(
    dist = dist_normal(.fitted, .se.fit),
    hilo = hilo(dist, level)
  ) %>% 
  ggplot(aes(x = hp, fill = ordered(cyl), color = ordered(cyl))) +
  geom_hilo_ribbon(aes(hilo = hilo)) +
  geom_point(aes(y = mpg), data = mtcars) +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Dark2")

Created on 2020-06-11 by the reprex package (v0.3.0)

mjskay commented 4 years ago

Hmm, you are doing that by mapping level onto an aesthetic that manipulates the lightness of the fill color or something like that? Interesting idea, could add something like that to lineribbon.

Currently lineribbon's default is to map level onto the fill color itself and then leave setting up the fill scale to the user. The example I gave above overrides that default and then uses alpha so the ribbons in each group overplot. The default lineribbon approach was designed more for something like this:

mtcars %>%
  group_by(cyl) %>%
  data_grid(hp = seq_range(hp, n = 101)) %>%
  broom::augment(m_mpg, newdata = .) %>%
  ggplot(aes(x = hp, color = ordered(cyl))) +
  stat_dist_lineribbon(aes(dist = dist_normal(.fitted, .se.fit))) +
  geom_point(aes(y = mpg), data = mtcars) +
  scale_fill_grey(start = .9, end = .7) +
  scale_color_brewer(palette = "Dark2")


Admittedly, legends are a bit of a blind spot for me: I basically never use them, as I prefer direct labelling.

mitchelloharawild commented 4 years ago

> Hmm, you are doing that by mapping level onto an aesthetic that manipulates the lightness of the fill color or something like that? Interesting idea, could add something like that to lineribbon.

Yes, it converts the RGB colour to HSL and changes the luminance according to the aesthetic (https://github.com/mitchelloharawild/distributional/blob/48a17b9916b5caaf069b4533a63851a3ea79bb35/R/geom_hilo.R#L171-L177). It can be tricky to choose the right range for the luminance, but some trial and error gave me seq(90 - pmin((n_prob - 1)*10, 30), 90). This can probably be improved by someone with a better understanding of colour theory.
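To make the RGB to HSL round trip concrete, here is a minimal base-R sketch of "keep hue and saturation, replace lightness". It is not the code from the linked geom_hilo.R; the helper name `set_lightness()` is hypothetical, lightness is assumed to be on a 0-100 scale (matching the numbers in the formula above), and any alpha channel is dropped.

```r
# Hypothetical helper (not distributional's implementation): set the HSL
# lightness (0-100) of colours, keeping hue and saturation. Base R only; alpha is dropped.
set_lightness <- function(col, l) {
  n <- max(length(col), length(l))
  rgb01 <- grDevices::col2rgb(rep_len(col, n)) / 255  # 3 x n matrix, values in [0, 1]
  l <- rep_len(l, n) / 100                             # target lightness, 0-100 -> [0, 1]
  clamp01 <- function(x) pmin(pmax(x, 0), 1)           # guard against rounding error
  out <- character(n)
  for (i in seq_len(n)) {
    r <- rgb01[1, i]; g <- rgb01[2, i]; b <- rgb01[3, i]
    mx <- max(r, g, b); mn <- min(r, g, b); d <- mx - mn
    # RGB -> hue in degrees
    h <- if (d == 0) {
      0
    } else if (mx == r) {
      60 * (((g - b) / d) %% 6)
    } else if (mx == g) {
      60 * ((b - r) / d + 2)
    } else {
      60 * ((r - g) / d + 4)
    }
    # RGB -> HSL saturation (computed from the colour's original lightness)
    l0 <- (mx + mn) / 2
    s <- if (d == 0) 0 else d / (1 - abs(2 * l0 - 1))
    # HSL -> RGB using the new lightness l[i]
    chroma <- (1 - abs(2 * l[i] - 1)) * s
    x <- chroma * (1 - abs((h / 60) %% 2 - 1))
    m <- l[i] - chroma / 2
    rgb_new <- switch(floor(h / 60) %% 6 + 1,
      c(chroma, x, 0), c(x, chroma, 0), c(0, chroma, x),
      c(0, x, chroma), c(x, 0, chroma), c(chroma, 0, x))
    v <- clamp01(rgb_new + m)
    out[i] <- grDevices::rgb(v[1], v[2], v[3])
  }
  out
}

# e.g. three shades of one group colour, one per interval level
set_lightness("#1B9E77", c(85, 75, 65))
```

Which level gets the lightest shade, and how far apart the lightness values are, is a design choice; the formula quoted above is one way to pick them.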

This luminance change currently happens in the geom (https://github.com/mitchelloharawild/distributional/blob/48a17b9916b5caaf069b4533a63851a3ea79bb35/R/geom_hilo.R#L148). Ideally it would be possible to provide fill_luminance and colour_luminance aesthetics that modify the fill and colour of any geom, but I think this would need to happen in ggplot2.

mitchelloharawild commented 4 years ago

There's also a feature of the guide (which is probably poor practice, but I don't know a better way) that switches the legend for the level aesthetic between a discrete and a continuous guide (https://github.com/mitchelloharawild/distributional/blob/48a17b9916b5caaf069b4533a63851a3ea79bb35/R/scale-level.R#L68-L120).

That way, if you set level = 50:99, it doesn't flood your screen with a legend and instead switches to a colourbar.

The correct approach for this is probably to default to a colourbar as level is numeric, but I think of the level as being discrete (ordered) or continuous depending on the number of levels that are shown.
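The switching rule described above can be sketched as a tiny decision function. This is only an illustration of the idea, not the scale-level.R code: the helper name `level_guide()` and the cut-off of 5 levels are hypothetical, and it assumes level stays numeric (on a continuous scale) so that either guide type is valid.

```r
# Hypothetical helper: choose a guide for the (numeric) level aesthetic based on
# how many distinct levels are shown. The cut-off of 5 is an arbitrary assumption.
level_guide <- function(level, max_discrete = 5) {
  if (length(unique(level)) > max_discrete) "colourbar" else "legend"
}

level_guide(c(50, 80, 95))  # a few levels: discrete-style legend
level_guide(50:99)          # many levels: colourbar
```

The returned string could then be passed to ggplot2::guides() for whichever aesthetic level is mapped to, since guides() accepts guide names such as "legend" and "colourbar".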

mjskay commented 4 years ago

Nice!

> Ideally it would be possible to provide fill_luminance and colour_luminance aesthetics which modify the fill and colour for any geom, but I think this would need to happen in ggplot2.

Yeah, some features to better support multivariate color scales in base ggplot2 would be cool. Barring that, creating a custom scale for luminance seems like a pretty reasonable workaround.

mitchelloharawild commented 2 years ago

Any ideas on how multivariate distributions from {distributional} could be plotted with {ggdist}? This need came up while rewriting the plotting functions in {fabletools} to use {ggdist}.

Perhaps another grouping stat needs to be created for each dimension of the distribution?

library(ggplot2)
library(distributional)
library(ggdist)

library(fable)
#> Loading required package: fabletools

fc <- as_tsibble(cbind(mdeaths, fdeaths), pivot_longer = FALSE) %>%
  model(VAR(vars(mdeaths, fdeaths) ~ AR(3))) %>% 
  forecast()

fc
#> # A fable: 24 x 5 [1M]
#> # Key:     .model [1]
#>    .model                        index .distribution .mean_mdeaths .mean_fdeaths
#>    <chr>                         <mth>        <dist>         <dbl>         <dbl>
#>  1 VAR(vars(mdeaths, fdeaths… 1980 Jan        MVN[2]         1486.          575.
#>  2 VAR(vars(mdeaths, fdeaths… 1980 Feb        MVN[2]         1445.          558.
#>  3 VAR(vars(mdeaths, fdeaths… 1980 Mar        MVN[2]         1369.          528.
#>  4 VAR(vars(mdeaths, fdeaths… 1980 Apr        MVN[2]         1340.          505.
#>  5 VAR(vars(mdeaths, fdeaths… 1980 May        MVN[2]         1327.          497.
#>  6 VAR(vars(mdeaths, fdeaths… 1980 Jun        MVN[2]         1349.          505.
#>  7 VAR(vars(mdeaths, fdeaths… 1980 Jul        MVN[2]         1395.          522.
#>  8 VAR(vars(mdeaths, fdeaths… 1980 Aug        MVN[2]         1442.          540.
#>  9 VAR(vars(mdeaths, fdeaths… 1980 Sep        MVN[2]         1477.          554.
#> 10 VAR(vars(mdeaths, fdeaths… 1980 Oct        MVN[2]         1495.          561.
#> # … with 14 more rows

fc %>% 
  ggplot(aes(x = as.Date(index), dist = .distribution, fill_ramp = stat(level))) + 
  stat_dist_lineribbon(fill = "blue", .width = c(.8, .95), point_interval = mean_qi)
#> Warning: Duplicated aesthetics after name standardisation: point_interval

fc %>% 
  autoplot()

Created on 2021-11-19 by the reprex package (v2.0.0)

mjskay commented 2 years ago

Hmm, very good question. You can finagle it by manually pulling out the marginal distributions...

library(ggplot2)
library(distributional)
library(ggdist)
library(fable)
#> Loading required package: fabletools

fc <- as_tsibble(cbind(mdeaths, fdeaths), pivot_longer = FALSE) %>%
  model(VAR(vars(mdeaths, fdeaths) ~ AR(3))) %>% 
  forecast()

fc
#> # A fable: 24 x 4 [1M]
#> # Key:     .model [1]
#>    .model                      index .distribution .mean[,"mdeaths~ [,"fdeaths"]
#>    <chr>                       <mth>        <dist>            <dbl>        <dbl>
#>  1 VAR(vars(mdeaths, fdeat~ 1980 Jan        MVN[2]            1486.         575.
#>  2 VAR(vars(mdeaths, fdeat~ 1980 Feb        MVN[2]            1445.         558.
#>  3 VAR(vars(mdeaths, fdeat~ 1980 Mar        MVN[2]            1369.         528.
#>  4 VAR(vars(mdeaths, fdeat~ 1980 Apr        MVN[2]            1340.         505.
#>  5 VAR(vars(mdeaths, fdeat~ 1980 May        MVN[2]            1327.         497.
#>  6 VAR(vars(mdeaths, fdeat~ 1980 Jun        MVN[2]            1349.         505.
#>  7 VAR(vars(mdeaths, fdeat~ 1980 Jul        MVN[2]            1395.         522.
#>  8 VAR(vars(mdeaths, fdeat~ 1980 Aug        MVN[2]            1442.         540.
#>  9 VAR(vars(mdeaths, fdeat~ 1980 Sep        MVN[2]            1477.         554.
#> 10 VAR(vars(mdeaths, fdeat~ 1980 Oct        MVN[2]            1495.         561.
#> # ... with 14 more rows

fc_with_margins <- dplyr::bind_cols(
  tibble::as_tibble(vctrs::vec_rep(fc, times = ncol(mean(fc$.distribution)))),
  tibble::tibble(
    .variable = rep(colnames(mean(fc$.distribution)), each = nrow(fc)),
    .marginal_dist = dist_normal(c(mean(fc$.distribution)), c(sqrt(variance(fc$.distribution))))
  )
)

fc_with_margins
#> # A tibble: 48 x 6
#>    .model            index .distribution .mean[,"mdeaths~ [,"fdeaths"] .variable
#>    <chr>             <mth>        <dist>            <dbl>        <dbl> <chr>    
#>  1 VAR(vars(mdea~ 1980 Jan        MVN[2]            1486.         575. mdeaths  
#>  2 VAR(vars(mdea~ 1980 Feb        MVN[2]            1445.         558. mdeaths  
#>  3 VAR(vars(mdea~ 1980 Mar        MVN[2]            1369.         528. mdeaths  
#>  4 VAR(vars(mdea~ 1980 Apr        MVN[2]            1340.         505. mdeaths  
#>  5 VAR(vars(mdea~ 1980 May        MVN[2]            1327.         497. mdeaths  
#>  6 VAR(vars(mdea~ 1980 Jun        MVN[2]            1349.         505. mdeaths  
#>  7 VAR(vars(mdea~ 1980 Jul        MVN[2]            1395.         522. mdeaths  
#>  8 VAR(vars(mdea~ 1980 Aug        MVN[2]            1442.         540. mdeaths  
#>  9 VAR(vars(mdea~ 1980 Sep        MVN[2]            1477.         554. mdeaths  
#> 10 VAR(vars(mdea~ 1980 Oct        MVN[2]            1495.         561. mdeaths  
#> # ... with 38 more rows, and 1 more variable: .marginal_dist <dist>

fc_with_margins %>%
  ggplot(aes(x = as.Date(index), dist = .marginal_dist, fill_ramp = stat(level))) + 
  stat_dist_lineribbon(fill = "blue", .width = c(.8, .95), point_interval = mean_qi) +
  facet_wrap(~ .variable, scales = "free")

Created on 2021-11-18 by the reprex package (v2.0.1)

But this is obviously not a great solution.

One solution might be to create a function for "flattening" or "unnesting" columns with multivariate distributions into two columns, one with the name of the margin and one with the marginal distribution (basically a generic version of the operation above). I think that would provide the needed flexibility to either facet on the marginal variable or map it onto the group aesthetic to create multiple ribbons within one facet. Dunno if something like that should live in ggdist or in distributional; I'm fine with either (and I'm not sure what your plans are for how the multivariate API will look in distributional).
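A generic "unnest margins" operation could look roughly like the sketch below. To stay self-contained it does not use {distributional}: it assumes each multivariate normal is stored as a plain `list(mean = <named vector>, sigma = <covariance matrix>)`, and the helper name `unnest_margins()` and the toy data are hypothetical. It relies on the fact that the j-th margin of a multivariate normal is normal with mean mu[j] and standard deviation sqrt(Sigma[j, j]).

```r
# Hypothetical sketch: each MVN is stored as list(mean = <named vector>,
# sigma = <covariance matrix>), NOT as a {distributional} <dist> object.
unnest_margins <- function(df, col) {
  dists <- df[[col]]
  vars <- names(dists[[1]]$mean)  # margin names, assumed shared by all rows
  d <- length(vars)
  # repeat every row once per margin, margin-major (like vctrs::vec_rep(times = d))
  out <- df[rep(seq_len(nrow(df)), times = d), setdiff(names(df), col), drop = FALSE]
  out$.variable <- rep(vars, each = nrow(df))
  # j-th margin of an MVN: Normal(mu[j], sqrt(Sigma[j, j]))
  out$.mean <- unlist(lapply(seq_len(d), function(j)
    vapply(dists, function(x) unname(x$mean[j]), numeric(1))))
  out$.sd <- unlist(lapply(seq_len(d), function(j)
    vapply(dists, function(x) sqrt(x$sigma[j, j]), numeric(1))))
  rownames(out) <- NULL
  out
}

# toy input: two time steps, two variates (a, b)
toy <- data.frame(step = 1:2)
toy$dist <- list(
  list(mean = c(a = 1, b = 2), sigma = diag(c(4, 9))),
  list(mean = c(a = 3, b = 4), sigma = diag(c(1, 16)))
)
res <- unnest_margins(toy, "dist")
res
```

With {distributional}, the `.mean` and `.sd` columns could then be turned back into marginal distributions with dist_normal(), as in the code above, and `.variable` used for faceting or grouping.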

mitchelloharawild commented 2 years ago

Yes - this is similar to how fable currently handles multivariate distributions. However as fable works on plotting <hilo> intervals rather than distributions, it is easier to compute the intervals of the multivariate distribution and then plot each interval separately. While I agree that having a method for marginalising multivariate distributions would be neat, I think that requiring it for plotting limits potential future graphics options.

With the rework of distributional methods (https://github.com/mitchelloharawild/distributional/issues/52#issuecomment-939891809), multivariate distributions now give matrix outputs - each variate as a (named) column. I think it would make most sense for {ggdist} to take this output and rearrange it into a long form - creating a new group from the column names. I'm not sure how this would look internally for {ggdist}, but I imagine that it could be placed in the Stat calculations. It might still be tricky in {fabletools}, as I'm not sure if after_stat() variables can be used for facetting - I think I'd still have to work around it somehow.
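The "rearrange into long form" step is essentially a reshape of an n x d matrix into n * d rows plus a group column taken from the column names. A minimal base-R sketch follows; the helper name `matrix_to_long()` is hypothetical, and the example matrix just reuses the (rounded) first two forecast means printed above.

```r
# Hypothetical sketch: turn an n x d matrix (one named column per variate) into
# long form, creating a .variable column that can be used as a group or facet.
matrix_to_long <- function(m, .row = seq_len(nrow(m))) {
  data.frame(
    .row = rep(.row, times = ncol(m)),
    .variable = rep(colnames(m), each = nrow(m)),
    .value = c(m),  # c() flattens column-major, which lines up with .variable
    stringsAsFactors = FALSE
  )
}

means <- cbind(mdeaths = c(1486, 1445), fdeaths = c(575, 558))
long <- matrix_to_long(means)
long
```

Because c() flattens a matrix column by column, repeating the column names with `each = nrow(m)` keeps every value paired with the right variate.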

mjskay commented 2 years ago

> With the rework of distributional methods (mitchelloharawild/distributional#52 (comment)), multivariate distributions now give matrix outputs - each variate as a (named) column. I think it would make most sense for {ggdist} to take this output and rearrange it into a long form - creating a new group from the column names. I'm not sure how this would look internally for {ggdist}, but I imagine that it could be placed in the Stat calculations.

Hmm, this could probably happen somewhere in the point_interval() family. Then it would be supported either in the stats or as a pre-processing step with point_interval() (after which you could use the geom form instead of the stat, which would let you facet on the variable without using after_stat()).
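As a rough illustration of a point_interval()-style pre-processing step that handles matrix-valued output, here is a base-R sketch working on a matrix of draws (rows are draws, named columns are variates). It is not ggdist's implementation: the function name `mean_qi_by_variate()` is hypothetical, and it only borrows ggdist's `.lower` / `.upper` / `.width` column names. Its long-form output carries a `.variable` column that could be faceted on directly.

```r
# Hypothetical point_interval()-style summary (not ggdist's implementation):
# draws is a matrix with one row per draw and one named column per variate.
# Returns long-form data with one row per variate and interval width.
mean_qi_by_variate <- function(draws, .width = 0.95) {
  rows <- lapply(colnames(draws), function(v) {
    x <- draws[, v]
    do.call(rbind, lapply(.width, function(w) {
      # equal-tailed (quantile) interval containing a fraction w of the draws
      q <- stats::quantile(x, c((1 - w) / 2, (1 + w) / 2), names = FALSE)
      data.frame(.variable = v, .value = mean(x), .lower = q[1], .upper = q[2],
                 .width = w, stringsAsFactors = FALSE)
    }))
  })
  out <- do.call(rbind, rows)
  rownames(out) <- NULL
  out
}

draws <- cbind(a = 1:101, b = 2 * (1:101))
mean_qi_by_variate(draws, .width = c(0.8, 0.95))
```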