tidyverse / dbplyr

Database (DBI) backend for dplyr
https://dbplyr.tidyverse.org

across (via across_funs) errors if .fns is a list and uses standard evaluation #662

Closed. swnydick closed this issue 2 years ago.

swnydick commented 3 years ago

I am trying to write an API where the user can specify one or more aggregation functions to apply to one or more columns of a database table. The functions therefore have to be passed programmatically, ideally as a list of character strings. This works on a local data frame (see below), but the same call on a database backend (SQLite, MySQL, etc.) errors. There may be a workaround (I haven't found one), but this behavior still feels unintended.

library(dbplyr)
library(dplyr, warn.conflicts = FALSE)

sessionInfo()
#> R version 4.1.0 (2021-05-18)
#> Platform: x86_64-apple-darwin17.0 (64-bit)
#> Running under: macOS Mojave 10.14.6
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] dplyr_1.0.6  dbplyr_2.1.1
#> 
#> loaded via a namespace (and not attached):
#>  [1] rstudioapi_0.13   knitr_1.33        magrittr_2.0.1    tidyselect_1.1.1 
#>  [5] R6_2.5.0          rlang_0.4.11      fansi_0.4.2       stringr_1.4.0    
#>  [9] highr_0.9         tools_4.1.0       xfun_0.23         utf8_1.2.1       
#> [13] cli_2.5.0         DBI_1.1.1         withr_2.4.2       ellipsis_0.3.2   
#> [17] htmltools_0.5.1.1 yaml_2.2.1        digest_0.6.27     assertthat_0.2.1 
#> [21] tibble_3.1.2      lifecycle_1.0.0   crayon_1.4.1      purrr_0.3.4      
#> [25] vctrs_0.3.8       fs_1.5.0          glue_1.4.2        evaluate_0.14    
#> [29] rmarkdown_2.8     reprex_2.0.0      stringi_1.6.2     compiler_4.1.0   
#> [33] pillar_1.6.1      generics_0.1.0    pkgconfig_2.0.3

db <- memdb_frame(iris)

st <- list("mean", "sd")

# works (single function after unquoting)
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = (!!st)[[1]]))
#> Warning: Missing values are always removed in SQL.
#> Use `mean(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 1]
#> # Database: sqlite 3.35.5 [:memory:]
#>   Sepal.Length
#>          <dbl>
#> 1         5.84

# works (hard-coding mean/sd directly)
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list("mean", "sd")))
#> Warning: Missing values are always removed in SQL.
#> Use `sd(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 2]
#> # Database: sqlite 3.35.5 [:memory:]
#>   `Sepal.Length_"mean"` `Sepal.Length_"sd"`
#>                   <dbl>               <dbl>
#> 1                  5.84               0.828

# breaks (using object)
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = (!!st)))
#> Error: `.fns` argument to dbplyr::across() must be a NULL, a function name, formula, or list

# works
iris %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = (!!st)))
#>   Sepal.Length_1 Sepal.Length_2
#> 1       5.843333      0.8280661

Created on 2021-06-10 by the reprex package (v2.0.0)

mgirlich commented 3 years ago

You were nearly there:

library(dplyr, warn.conflicts = FALSE)
library(dbplyr, warn.conflicts = FALSE)

db <- memdb_frame(iris)

# Note: a list of characters (currently) produces ugly names
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list("mean", "sd")))
#> Warning: Missing values are always removed in SQL.
#> Use `mean(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> Warning: Missing values are always removed in SQL.
#> Use `sd(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 2]
#> # Database: sqlite 3.35.5 [:memory:]
#>   `Sepal.Length_"mean"` `Sepal.Length_"sd"`
#>                   <dbl>               <dbl>
#> 1                  5.84               0.828
# and they are different for local data frames
iris %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list("mean", "sd")))
#>   Sepal.Length_1 Sepal.Length_2
#> 1       5.843333      0.8280661

# you need to explicitly use `list()`
st <- list("mean", "sd")
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list(!!!st)))
#> Warning: Missing values are always removed in SQL.
#> Use `sd(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 2]
#> # Database: sqlite 3.35.5 [:memory:]
#>   `Sepal.Length_"mean"` `Sepal.Length_"sd"`
#>                   <dbl>               <dbl>
#> 1                  5.84               0.828

# nicer names when using symbols instead of character
st2 <- lapply(st, sym)
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list(!!!st2)))
#> Warning: Missing values are always removed in SQL.
#> Use `sd(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 2]
#> # Database: sqlite 3.35.5 [:memory:]
#>   Sepal.Length_mean Sepal.Length_sd
#>               <dbl>           <dbl>
#> 1              5.84           0.828

# or simply use a character vector
st_char <- unlist(st)
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = !!st_char))
#> Warning: Missing values are always removed in SQL.
#> Use `sd(x, na.rm = TRUE)` to silence this warning
#> This warning is displayed only once per session.
#> # Source:   lazy query [?? x 2]
#> # Database: sqlite 3.35.5 [:memory:]
#>   Sepal.Length_mean Sepal.Length_sd
#>               <dbl>           <dbl>
#> 1              5.84           0.828

Created on 2021-06-11 by the reprex package (v2.0.0)
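The difference between the failing `.fns = (!!st)` and the working `.fns = list(!!!st)` can be reproduced at the expression level with base R alone, using `bquote()` and its `..()` splice operator (available since R 4.0.0). This is a sketch that only builds and inspects expressions; it does not call dbplyr, and it assumes, as the error message and the workarounds above suggest, that dbplyr's `across()` recognizes a literal call to `list()` but not an already-evaluated list value. The names `inlined` and `spliced` are illustrative.

```r
st <- list("mean", "sd")

# .() inlines the *value* of st: the argument becomes a list object
# (base-R counterpart of `.fns = (!!st)`)
inlined <- bquote(across(.fns = .(st)))

# ..() splices the elements of st into a literal list() call
# (base-R counterpart of `.fns = list(!!!st)`)
spliced <- bquote(across(.fns = list(..(st))), splice = TRUE)

# Both deparse to the same text ...
deparse(inlined)   # "across(.fns = list(\"mean\", \"sd\"))"
deparse(spliced)   # "across(.fns = list(\"mean\", \"sd\"))"

# ... but only the spliced version is a call whose head is `list`
is.call(inlined$.fns)                       # FALSE: a list object
is.list(inlined$.fns)                       # TRUE
is.call(spliced$.fns)                       # TRUE
identical(spliced$.fns[[1]], quote(list))   # TRUE
```

Because the two forms print identically, the error message is easy to misread: one argument is a call to `list()`, the other is a list object that merely deparses like one.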

swnydick commented 3 years ago

Thanks for the workaround. The character list was just a very simplified example for reprex purposes (I noticed the odd names but ignored them, since a shorter reprex is generally better and the names weren't important here). That said, my comment was less about finding a workaround than about the behavior being inconsistent. In particular, the error message is very misleading when unquoting a list stored in a variable, because the unquoted value is in fact a list:

#> Error: `.fns` argument to dbplyr::across() must be a NULL, a function name, formula, or list

Ideally, writing an argument out directly (NSE) and unquoting an equivalent value (SE) should behave the same in most situations. There are obvious exceptions, such as needing to pass symbols rather than the functions themselves for lazy evaluation in SE, especially with dbplyr, where function names are translated to SQL. But I can't see any reason why

st <- list("mean", "sd")
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list("mean", "sd")))
db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = !!st))

should behave differently. I also tried the base-R route (ignoring the known environment issues with `eval()`/`substitute()`), and it fails in exactly the same way.

library(dplyr, warn.conflicts = FALSE)
library(dbplyr, warn.conflicts = FALSE)

db   <- memdb_frame(iris)

st   <- syms(list("mean", "sd"))
expr <- substitute(db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = stats)), list(stats = st))
expr
#> db %>% summarize(across(.cols = all_of("Sepal.Length"), .fns = list(
#>     mean, sd)))
eval(expr)
#> Error: `.fns` argument to dbplyr::across() must be a NULL, a function name, formula, or list

Created on 2021-06-11 by the reprex package (v2.0.0)

So across()/across_funs() effectively requires `.fns` to be a literal `list()` call: the `list` itself must be written out (NSE) and cannot be substituted in, even though its contents can be unquoted (SE), as your workaround showed. That seems antithetical to the purpose of tidy evaluation.
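The base-R attempt above fails for the same reason as `(!!st)`: `substitute()` with `list(stats = st)` drops the list *value* into the expression, and it only deparses as `list(mean, sd)`. A possible base-R fix, sketched below and assuming (as the error message suggests) that dbplyr's `across()` accepts a genuine `list()` call but not a list value, is to substitute a call built with `as.call()` instead. The reduced `across(.fns = stats)` expression and the variable names are illustrative, and this has not been checked against dbplyr.

```r
# Symbols for the aggregation functions (base-R equivalent of rlang::syms())
st <- lapply(c("mean", "sd"), as.name)

# Substituting the list itself embeds a list *value*, as in the reprex above
as_value <- substitute(across(.fns = stats), list(stats = st))

# Substituting a call built with as.call() embeds a genuine list(mean, sd) call
fns_call <- as.call(c(as.name("list"), st))
as_call  <- substitute(across(.fns = stats), list(stats = fns_call))

print(as_value)          # across(.fns = list(mean, sd))
print(as_call)           # across(.fns = list(mean, sd))
is.call(as_value$.fns)   # FALSE: a list object
is.call(as_call$.fns)    # TRUE: a call whose head is `list`
```

In the full reprex this would mean passing `list(stats = as.call(c(as.name("list"), st)))` to `substitute()` before calling `eval()`.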