mitchelloharawild / distributional

Vectorised distributions for R
https://pkg.mitchelloharawild.com/distributional
GNU General Public License v3.0

Mutating probabilities from cdf() in a data frame #114

Status: closed (adhi-r closed this issue 1 week ago)

adhi-r commented 2 weeks ago

Hi, I'd like to use the distributional::cdf() function in a dplyr pipeline to mutate a column of probabilities. I'd expect this to be possible since the function is supposed to be vectorised, but I get a colnames error that I've never seen before.


I can't mutate probabilities from cdf() in a data frame / dplyr pipeline:

library(fpp3)

google_stock <- gafa_stock |>
  filter(Symbol == "GOOG", year(Date) >= 2015) |>
  mutate(day = row_number()) |>
  update_tsibble(index = day, regular = TRUE)
# Filter the year of interest
google_2015 <- google_stock |> filter(year(Date) == 2015)

g_fcasts <- google_2015 |>
  model(NAIVE(Close)) |>
  forecast(h = 10) 

g_fcasts |> 
  mutate(strike_price = 765,
         probability = distributional::cdf(Close, strike_price))

Error in `mutate()`:
ℹ In argument: `probability = distributional::cdf(Close, strike_price)`.
Caused by error in `FUN()`:
! attempt to set 'colnames' on an object with less than two dimensions
Backtrace:
  1. dplyr::mutate(...)
  9. distributional:::cdf.distribution(Close, strike_price)
 10. distributional:::dist_apply(x, cdf, q = q, ...)
 11. base::lapply(out, `colnames<-`, dn)
 12. base (local) FUN(X[[i]], ...)
 13. base::stop("attempt to set 'colnames' on an object with less than two dimensions")
adhi-r commented 2 weeks ago

@robjhyndman you might know how to do this?

mitchelloharawild commented 2 weeks ago

For this you would want to use a length-1 strike_price, although I can see how this result could be surprising. We settled on this behaviour after some discussion in #52; however, it could (and probably should) be revisited and improved.
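Why a length-1 value matters: inside mutate(), `strike_price = 765` is recycled into a column with one value per forecast row, so cdf() receives a vector of q values rather than a single number, and that is the path that fails in the backtrace above. The sketch below is a minimal illustration using plain distributional objects instead of the fable. The object names (`fc`, `strikes`, `p_scalar`, `p_pairwise`) and the distribution parameters are hypothetical stand-ins, and the element-wise vapply() pairing is a base-R workaround we are assuming here, not an API provided by the package.

```r
library(distributional)

# Hypothetical stand-in for the forecast column: ten normal distributions
# with a common mean and growing variance. dist_normal() takes the
# standard deviation, hence the sqrt().
fc <- dist_normal(mu = 759, sigma = sqrt(125 * 1:10))

# A single (length-1) q is applied to every distribution,
# giving one probability per distribution.
p_scalar <- cdf(fc, 765)

# Assumed workaround for a per-row vector of strikes: pair the i-th
# distribution with the i-th strike explicitly, so each call to cdf()
# only ever sees a length-1 distribution and a length-1 q.
strikes <- rep(765, length(fc))
p_pairwise <- vapply(
  seq_along(fc),
  function(i) cdf(fc[i], strikes[i]),
  numeric(1)
)
```

With a constant strike the two results agree. The vapply() version also handles strikes that differ from row to row, at the cost of one cdf() call per element.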

library(fpp3)
#> -- Attaching packages ---------------------------------------------- fpp3 0.5 --
#> v tibble      3.2.1     v tsibble     1.1.4
#> v dplyr       1.1.4     v tsibbledata 0.4.1
#> v tidyr       1.3.1     v feasts      0.3.2
#> v lubridate   1.9.3     v fable       0.3.4
#> v ggplot2     3.5.1     v fabletools  0.4.2
#> -- Conflicts ------------------------------------------------- fpp3_conflicts --
#> x lubridate::date()    masks base::date()
#> x dplyr::filter()      masks stats::filter()
#> x tsibble::intersect() masks base::intersect()
#> x tsibble::interval()  masks lubridate::interval()
#> x dplyr::lag()         masks stats::lag()
#> x tsibble::setdiff()   masks base::setdiff()
#> x tsibble::union()     masks base::union()

google_stock <- gafa_stock |>
  filter(Symbol == "GOOG", year(Date) >= 2015) |>
  mutate(day = row_number()) |>
  update_tsibble(index = day, regular = TRUE)
# Filter the year of interest
google_2015 <- google_stock |> filter(year(Date) == 2015)

g_fcasts <- google_2015 |>
  model(NAIVE(Close)) |>
  forecast(h = 10) 

g_fcasts |> 
  mutate(strike_price = 765,
         probability = distributional::cdf(Close, 765))
#> # A fable: 10 x 7 [1]
#> # Key:     Symbol, .model [1]
#>    Symbol .model         day        Close .mean strike_price probability
#>    <chr>  <chr>        <dbl>       <dist> <dbl>        <dbl>       <dbl>
#>  1 GOOG   NAIVE(Close)   253  N(759, 125)  759.          765       0.708
#>  2 GOOG   NAIVE(Close)   254  N(759, 250)  759.          765       0.651
#>  3 GOOG   NAIVE(Close)   255  N(759, 376)  759.          765       0.624
#>  4 GOOG   NAIVE(Close)   256  N(759, 501)  759.          765       0.608
#>  5 GOOG   NAIVE(Close)   257  N(759, 626)  759.          765       0.597
#>  6 GOOG   NAIVE(Close)   258  N(759, 751)  759.          765       0.588
#>  7 GOOG   NAIVE(Close)   259  N(759, 876)  759.          765       0.582
#>  8 GOOG   NAIVE(Close)   260 N(759, 1002)  759.          765       0.577
#>  9 GOOG   NAIVE(Close)   261 N(759, 1127)  759.          765       0.572
#> 10 GOOG   NAIVE(Close)   262 N(759, 1252)  759.          765       0.569

Created on 2024-06-26 with reprex v2.1.0

adhi-r commented 1 week ago

Thanks for the response, Mitchell! A few hours ago I found that using rowwise() actually solves this.

adhi-r commented 1 week ago

I should also mention that my minimal example has only one strike price, so yes, you can just hardcode it like that. In reality, I have many strikes I want to compute a cdf for, across many different models and forecasts. rowwise() makes that possible!
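The rowwise() approach evaluates cdf() once per row, so each call sees a single distribution and a single strike. Below is a minimal sketch under stated assumptions: the table `fc`, its model labels, the normal parameters and the three strike prices are all hypothetical, chosen only to mimic "many strikes across several models". In a real fable you would cross your forecasts with your strikes to get one row per forecast/strike pair before applying the same rowwise() step.

```r
library(distributional)
library(dplyr)

# Hypothetical forecasts: two models, each with one forecast distribution,
# crossed with three strike prices (one row per model/strike pair).
fc <- tibble(
  .model = rep(c("NAIVE", "DRIFT"), each = 3),
  Close = dist_normal(
    mu = rep(c(759, 762), each = 3),
    sigma = rep(c(11, 12), each = 3)
  ),
  strike_price = rep(c(750, 765, 780), times = 2)
)

# rowwise() makes mutate() evaluate cdf() separately for each row,
# pairing each distribution with its own strike price.
res <- fc |>
  rowwise() |>
  mutate(probability = cdf(Close, strike_price)) |>
  ungroup()
```

Because rowwise() runs the expression once per row, it can be slow on large tables. The trade-off is simplicity: no reshaping is needed, and each distribution can have its own strike.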

mitchelloharawild commented 1 week ago

Ah, I also remember from #52 that we added support for list-type inputs to vectorise across values. For example:

library(fpp3)
google_stock <- gafa_stock |>
  filter(Symbol == "GOOG", year(Date) >= 2015) |>
  mutate(day = row_number()) |>
  update_tsibble(index = day, regular = TRUE)
# Filter the year of interest
google_2015 <- google_stock |> filter(year(Date) == 2015)

g_fcasts <- google_2015 |>
  model(NAIVE(Close)) |>
  forecast(h = 10) 

g_fcasts |> 
  mutate(strike_price = 765,
         probability = distributional::cdf(Close, tibble(strike_price)))
#> # A fable: 10 x 7 [1]
#> # Key:     Symbol, .model [1]
#>    Symbol .model      day        Close .mean strike_price probability$strike_p~1
#>    <chr>  <chr>     <dbl>       <dist> <dbl>        <dbl>                  <dbl>
#>  1 GOOG   NAIVE(Cl~   253  N(759, 125)  759.          765                  0.708
#>  2 GOOG   NAIVE(Cl~   254  N(759, 250)  759.          765                  0.651
#>  3 GOOG   NAIVE(Cl~   255  N(759, 376)  759.          765                  0.624
#>  4 GOOG   NAIVE(Cl~   256  N(759, 501)  759.          765                  0.608
#>  5 GOOG   NAIVE(Cl~   257  N(759, 626)  759.          765                  0.597
#>  6 GOOG   NAIVE(Cl~   258  N(759, 751)  759.          765                  0.588
#>  7 GOOG   NAIVE(Cl~   259  N(759, 876)  759.          765                  0.582
#>  8 GOOG   NAIVE(Cl~   260 N(759, 1002)  759.          765                  0.577
#>  9 GOOG   NAIVE(Cl~   261 N(759, 1127)  759.          765                  0.572
#> 10 GOOG   NAIVE(Cl~   262 N(759, 1252)  759.          765                  0.569
#> # i abbreviated name: 1: probability$strike_price

Created on 2024-06-26 with reprex v2.1.0