tidyverse / purrr

A functional programming toolkit for R
https://purrr.tidyverse.org/

Calling non-standard functions (dplyr::n()) within map2() and across(). #1101

Closed: JustGitting closed this issue 10 months ago

JustGitting commented 1 year ago

I want to programmatically summarize/pivot a data.frame and found a nice solution here:

How do I build a dplyr summarize statement programmatically? https://stackoverflow.com/questions/68853709/how-do-i-build-a-dplyr-summarize-statement-programmatically

A modified version of the example in the link is given below:

library(dplyr)
library(purrr)
library(stringr)

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('sum', 'mean')

pivot <- map2( value_fields, 
                fnc_str, 
                ~ iris %>%
                  group_by( Species ) %>% 
                  summarise( across(all_of(.x), 
                                    match.fun(.y), 
                                    .names = str_c("{.col}_", .y)), 
                             .groups = 'drop')
          )  %>%
            reduce(inner_join)

pivot %>% knitr::kable()

# |Species    | Sepal.Length_sum| Sepal.Width_mean|
# |:----------|----------------:|----------------:|
# |setosa     |            250.3|            3.428|
# |versicolor |            296.8|            2.770|
# |virginica  |            329.4|            2.974|

However, it does not work with the special dplyr function n(), which takes no arguments.

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('sum', 'n')

pivot <- map2( value_fields, 
               fnc_str, 
               ~ iris %>%
                 group_by( Species ) %>% 
                 summarise( across(all_of(.x), 
                                   match.fun(.y), 
                                   .names = str_c("{.col}_", .y)), 
                            .groups = 'drop')
)  %>%
  reduce(inner_join)

# Error in `map2()`:
#   ℹ In index: 2.
# Caused by error in `summarise()`:
#   ℹ In argument: `across(all_of(.x), match.fun(.y), .names = str_c("{.col}_", .y))`.
# ℹ In group 1: `Species = setosa`.
# Caused by error in `across()`:
#   ! Can't compute column `Sepal.Width_n`.
# Caused by error:
# ! unused argument (Sepal.Width)
# Run `rlang::last_trace()` to see where the error occurred.
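The last line of this error points at the cause: across() called n() with the column as its argument, and n() declares no arguments at all. Here is a quick check, assuming dplyr 1.1.x (the version used in this thread):

```r
library(dplyr)

# n() is declared with an empty argument list
n_formals <- length(formals(dplyr::n))
n_formals  # 0

# Supplying any argument fails during argument matching, before n() looks for
# a data mask. This is what happens when across() passes it a column.
msg <- tryCatch(dplyr::n(1), error = conditionMessage)
msg  # "unused argument (1)"
```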

The problem is how n() is called within across(): across() passes each selected column to the function, but n() takes no arguments, hence the "unused argument (Sepal.Width)" error. I found a workaround for n() in the following Stack Overflow post.

How to count rows by group with n() inside dplyr::across()? https://stackoverflow.com/questions/66161658/how-to-count-rows-by-group-with-n-inside-dplyracross

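This example does not work:

library(dplyr)

# Toy data for illustration (hypothetical)
df <- data.frame(group = c("a", "a", "b"), value = c(1, 2, 3))

df %>%
  group_by(group) %>% 
  summarise(across(value, list(sum = sum, count = n)), .groups = 'drop')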

Passing the lambda ~ n() instead of the bare n works:

df %>%
  group_by(group) %>% 
  summarise(across(value, list(sum = sum, count = ~ n() )), .groups = 'drop')
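The same ~ n() fix can be checked on iris, which the rest of this thread uses. This is a minimal sketch; the Sepal.Length column and the sum/count names are chosen only for illustration:

```r
library(dplyr)

# ~ n() becomes a lambda that receives the column but ignores it,
# so n() is called with no arguments inside each group
res <- iris %>%
  group_by(Species) %>%
  summarise(across(Sepal.Length, list(sum = sum, count = ~ n())),
            .groups = "drop")
res
```

With a named list of functions, across() names the outputs "{.col}_{.fn}". That gives Sepal.Length_sum and Sepal.Length_count.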

I applied this solution to the map2() example, but it fails.

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('n', 'n')

pivot <- map2( value_fields, 
                fnc_str, 
                ~ iris %>%
                  group_by( Species ) %>% 
                  summarise( across(all_of(.x), 
                                    ~match.fun(.y)(), 
                                    .names = str_c("{.col}_", .y)), 
                             .groups = 'drop')
          )  %>%
            reduce(inner_join)
# Error in `map2()`:
#   ℹ In index: 1.
# Caused by error in `summarise()`:
#   ℹ In argument: `across(...)`.
# ℹ In group 1: `Species = setosa`.
# Caused by error in `across()`:
#   ! Can't compute column `Sepal.Length_n`.
# Caused by error in `match.fun()`:
# ! the ... list contains fewer than 2 elements
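A likely explanation for this error: the inner ~ turns the body into a new lambda with its own .x and .y arguments, so .y no longer refers to the function name from map2(). This assumes across() converts formulas the way rlang::as_function() does. Under that assumption, the lambda's .y defaults to ..2, its second argument. across() supplies only one argument, which fits the "fewer than 2 elements" message. The generated arguments can be inspected directly:

```r
library(rlang)

# Convert a one-sided formula into a function, as formula lambdas are built
f <- as_function(~ match.fun(.y)())
lambda_args <- names(formals(f))
lambda_args

# Inside the lambda, .y defaults to ..2, not to map2()'s .y
formals(f)$.y
```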

After a lot of experimenting, I found a workaround: wrapping match.fun() in an anonymous function.

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('n', 'n')

pivot <- map2( value_fields, 
      fnc_str, 
      ~ iris %>%
        group_by( Species ) %>% 
        summarise( across(all_of(.x), 
                          # ~match.fun(.y)(), 
                          function(.y) match.fun(.y)(),
                          .names = str_c("{.col}_", .y)), 
                   .groups = 'drop')
)  %>%
  reduce(inner_join)

pivot %>% knitr::kable()

# |Species    | Sepal.Length_n| Sepal.Width_n|
# |:----------|--------------:|-------------:|
# |setosa     |             50|            50|
# |versicolor |             50|            50|
# |virginica  |             50|            50|

However, this does not work with functions that take arguments, because the column values are never passed to them (sum() is called with no input, so every sum is 0).

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('n', 'sum')

pivot <- map2( value_fields, 
               fnc_str, 
               ~ iris %>%
                 group_by( Species ) %>% 
                 summarise( across(all_of(.x), 
                                   # ~match.fun(.y)(), 
                                   function(.y) match.fun(.y)(),
                                   .names = str_c("{.col}_", .y)), 
                            .groups = 'drop')
)  %>%
  reduce(inner_join)

pivot %>% knitr::kable()

# |Species    | Sepal.Length_n| Sepal.Width_sum|
# |:----------|--------------:|---------------:|
# |setosa     |             50|               0|
# |versicolor |             50|               0|
# |virginica  |             50|               0|
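The zeros in Sepal.Width_sum are what a wrapper that drops its argument produces. In the sketch below, drop_arg and keep_arg are hypothetical names. drop_arg roughly mirrors what function(.y) match.fun(.y)() ends up doing for "sum":

```r
x <- iris$Sepal.Width

# Ignores the column and calls sum() with nothing
drop_arg <- function(x) sum()
# Forwards the column, which is what sum() actually needs
keep_arg <- function(x) sum(x)

drop_arg(x)  # 0
keep_arg(x)  # 458.6
```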

Next, I tried using an if() statement to call match.fun() the right way depending on how many arguments the summarizing function takes, but I ended up with errors.

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('n', 'sum')

pivot <- map2( value_fields, 
               fnc_str, 
               ~ iris %>%
                 group_by( Species ) %>% 
                 summarise( across(all_of(.x), 
                                   function(.y){ 
                                      if(length(as.list(args(.y))) == 1L){ # check number of arguments to function .
                                        match.fun(.y)() # no arguments
                                        }else{
                                          match.fun(.y) # takes the column values as before.
                                        } 
                                     },
                                   .names = str_c("{.col}_", .y) ), 
                            .groups = 'drop')
)  %>%
  reduce(inner_join)

# Error in `map2()`:
#  ℹ In index: 1.
# Caused by error in `summarise()`:
#   ℹ In argument: `across(...)`.
# ℹ In group 1: `Species = setosa`.
# Caused by error in `across()`:
#   ! Can't compute column `Sepal.Length_n`.
# Caused by error in `get()`:
# ! object 'Sepal.Length' of mode 'function' was not found
# Run `rlang::last_trace()` to see where the error occurred.
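The arity check itself is a workable idea. The snag is that inside function(.y), .y is the column that across() passes in, not the function name from map2(). Below is a minimal sketch that moves the check outside across() into a hypothetical helper, as_summary_fn(). The helper resolves the name once and wraps zero-argument functions such as n(). It uses args() because primitives like sum report NULL from formals() directly:

```r
library(dplyr)
library(purrr)

# Hypothetical helper: look up a summary function by name and, if it takes
# no arguments (like dplyr::n), wrap it so it accepts and ignores the column
as_summary_fn <- function(name) {
  f <- match.fun(name)
  # args() also works for primitives such as sum, where formals() is NULL
  if (length(formals(args(f))) == 0L) function(x) f() else f
}

value_fields <- c("Sepal.Length", "Sepal.Width")
fnc_str <- c("n", "sum")

pivot <- map2(value_fields, fnc_str, function(var, fun) {
  iris %>%
    group_by(Species) %>%
    summarise(across(all_of(var), as_summary_fn(fun),
                     .names = paste0("{.col}_", fun)),
              .groups = "drop")
}) %>%
  reduce(inner_join, by = "Species")

pivot
```

Passing by = "Species" to inner_join() through reduce() only makes the join key explicit and silences the "Joining with" message.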

I'm out of ideas. How can I get map2() and across() to run both "normal" functions and the special n(), which takes no arguments?

> sessionInfo()
R version 4.2.2 (2022-10-31 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

Matrix products: default

other attached packages:
 [1] lubridate_1.9.2 forcats_1.0.0   readr_2.1.4     tidyr_1.3.0     tibble_3.2.1    ggplot2_3.4.3   tidyverse_2.0.0 stringr_1.5.0  
 [9] purrr_1.0.2     dplyr_1.1.2   

hadley commented 10 months ago

Could you please rework your reproducible example to use the reprex package? That makes it easier to see both the input and the output, formatted in such a way that I can easily re-run in a local session. Thanks!

JustGitting commented 10 months ago

Hi @hadley,

I've updated R to 4.3.2 (and all outdated packages) to rule out problems with old versions. The reprex output is below; I hope it worked correctly.

library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(purrr)
library(stringr)

value_fields <- c('Sepal.Length', 'Sepal.Width')
fnc_str <- c('sum', 'n')

pivot <- map2( value_fields, 
               fnc_str, 
               ~ iris %>%
                 group_by( Species ) %>% 
                 summarise( across(all_of(.x), 
                                   match.fun(.y), 
                                   .names = str_c("{.col}_", .y)), 
                            .groups = 'drop')
)  %>%
  reduce(inner_join)
#> Error in `map2()`:
#> ℹ In index: 2.
#> Caused by error in `summarise()`:
#> ℹ In argument: `across(all_of(.x), match.fun(.y), .names =
#>   str_c("{.col}_", .y))`.
#> ℹ In group 1: `Species = setosa`.
#> Caused by error in `across()`:
#> ! Can't compute column `Sepal.Width_n`.
#> Caused by error:
#> ! unused argument (Sepal.Width)
#> Backtrace:
#>      ▆
#>   1. ├─... %>% reduce(inner_join)
#>   2. ├─purrr::reduce(., inner_join)
#>   3. │ └─purrr:::reduce_impl(.x, .f, ..., .init = .init, .dir = .dir)
#>   4. │   └─purrr:::reduce_init(.x, .init, left = left, error_call = .purrr_error_call)
#>   5. │     └─rlang::is_empty(x)
#>   6. ├─purrr::map2(...)
#>   7. │ └─purrr:::map2_("list", .x, .y, .f, ..., .progress = .progress)
#>   8. │   ├─purrr:::with_indexed_errors(...)
#>   9. │   │ └─base::withCallingHandlers(...)
#>  10. │   ├─purrr:::call_with_cleanup(...)
#>  11. │   └─global .f(.x[[i]], .y[[i]], ...)
#>  12. │     └─iris %>% group_by(Species) %>% ...
#>  13. ├─dplyr::summarise(...)
#>  14. ├─dplyr:::summarise.grouped_df(...)
#>  15. │ └─dplyr:::summarise_cols(.data, dplyr_quosures(...), by, "summarise")
#>  16. │   ├─base::withCallingHandlers(...)
#>  17. │   └─dplyr:::map(quosures, summarise_eval_one, mask = mask)
#>  18. │     └─base::lapply(.x, .f, ...)
#>  19. │       └─dplyr (local) FUN(X[[i]], ...)
#>  20. │         ├─base::withCallingHandlers(...)
#>  21. │         └─mask$eval_all_summarise(quo)
#>  22. │           └─dplyr (local) eval()
#>  23. └─base::.handleSimpleError(...)
#>  24.   └─dplyr (local) h(simpleError(msg, call))
#>  25.     └─rlang::abort(msg, call = call("across"), parent = cnd)

Created on 2023-11-02 with reprex v2.0.2

hadley commented 10 months ago

Thanks. I made it a bit simpler to illustrate the key problem:

library(tidyverse)

vars <- c('Sepal.Length', 'Sepal.Width')
funs <- c('sum', 'n')

map2(vars, funs, function(var, fun) {
  iris |> summarise(across(all_of(var), match.fun(fun)))
})
#> Error in `map2()`:
#> ℹ In index: 2.
#> Caused by error in `summarise()`:
#> ℹ In argument: `across(all_of(var), match.fun(fun))`.
#> Caused by error in `across()`:
#> ! Can't compute column `Sepal.Width`.
#> Caused by error:
#> ! unused argument (Sepal.Width)

Created on 2023-11-02 with reprex v2.0.2

You can resolve this just by wrapping n() in a function that does take an argument:

library(tidyverse)

vars <- c('Sepal.Length', 'Sepal.Width')
funs <- list(sum, function(x) n())

map2(vars, funs, function(var, fun) {
  iris |> summarise(across(all_of(var), fun))
})
#> [[1]]
#>   Sepal.Length
#> 1        876.5
#> 
#> [[2]]
#>   Sepal.Width
#> 1         150

Created on 2023-11-02 with reprex v2.0.2

And then you can get better names by feeding across() a named list:

library(tidyverse)

vars <- c('Sepal.Length', 'Sepal.Width')
funs <- list(list(sum = sum), list(n = function(x) n()))

map2(vars, funs, function(var, fun) {
  iris |> summarise(across(all_of(var), fun))
})
#> [[1]]
#>   Sepal.Length_sum
#> 1            876.5
#> 
#> [[2]]
#>   Sepal.Width_n
#> 1           150

Created on 2023-11-02 with reprex v2.0.2
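To get back to the grouped, joined table from the original question, the named-list approach fits directly into the first example. This sketch assumes Species is the join key; passing by explicitly just silences inner_join()'s message:

```r
library(dplyr)
library(purrr)

vars <- c("Sepal.Length", "Sepal.Width")
# Each element is a named list, so across() names columns "{.col}_{.fn}"
funs <- list(list(sum = sum), list(n = function(x) n()))

pivot <- map2(vars, funs, function(var, fun) {
  iris |>
    group_by(Species) |>
    summarise(across(all_of(var), fun), .groups = "drop")
}) |>
  reduce(inner_join, by = "Species")

pivot
```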

JustGitting commented 10 months ago

Thanks @hadley!