r-lib / rlang

Low-level API for programming with R
https://rlang.r-lib.org

`!!!` and `call_match()` #1654

Open t-kalinowski opened 9 months ago

t-kalinowski commented 9 months ago

In keras, I naively tried unpacking arguments with !!! like this, which resulted in an error:

initializers <- list(kernel_initializer = tf$ones, 
                     bias_initializer = tf$ones)

layer_dense(model, 1, activation = "relu", !!!initializers)

The issue is that !!!initializers is matched to the 4th positional argument of layer_dense() (use_bias) rather than being spliced in. layer_dense() uses match.call() under the hood, and this is what match.call() returns:

layer_dense(object = model, units = 1, activation = "relu", use_bias = !!!initializers)
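
The positional matching can be reproduced with base R alone. The sketch below uses a hypothetical stand-in, toy_layer(), that copies only the first four formals of layer_dense(); it is not the keras implementation. Because match.call() never evaluates its arguments, the undefined symbol `model` is harmless here.

```r
# toy_layer() is a hypothetical stand-in for layer_dense(); only the
# first four formals are reproduced, which is enough to show the matching.
toy_layer <- function(object, units, activation = NULL, use_bias = TRUE, ...) {
  match.call()
}

initializers <- list(kernel_initializer = "ones", bias_initializer = "ones")

# `!!!initializers` parses as three nested calls to `!`, so base R sees an
# ordinary unnamed argument and matches it to the next free formal.
cl <- toy_layer(model, 1, activation = "relu", !!!initializers)

names(cl)
identical(cl$use_bias, quote(!!!initializers))
```

The unnamed `!!!initializers` expression ends up under the name use_bias, which is the behaviour described above.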

rlang::call_match() returns the same as base::match.call() in this situation. It would be nice if rlang::call_match() could do better than match.call() here by splicing in the arguments supplied with !!!, and return this instead:

layer_dense(object = model, units = 1, activation = "relu", 
            kernel_initializer = tf$ones, bias_initializer = tf$ones)

t-kalinowski commented 9 months ago

I ended up working around this by first capturing the call with exprs(), which supports injection with !!!, and then passing the result to match.call(). Something like this (simplified from the actual version):

capture_args2 <- function(ignore = NULL, 
                          envir = parent.frame(), 
                          fn = sys.function(-1)) {

  cl0 <- cl <- sys.call(-1)

  # first defuse rlang !!! and := in calls
  cl[[1L]] <- quote(rlang::exprs)
  cl_exprs <- eval(cl, parent.frame(2L))

  # build up a call to base::list() using the exprs
  cl <- as.call(c(list, cl_exprs))

  # match.call()
  cl <- match.call(fn, cl,
                   expand.dots = !"..." %in% ignore,
                   envir = parent.frame(2L))

  # filter out args to ignore
  for(ig in intersect(names(cl), ignore))
    cl[[ig]] <- NULL

  # eval and capture args
  eval(cl, envir = parent.frame(2L))

}

layer_dense <- function(....) {
  args <- capture_args2()
  ....
}

lionel- commented 9 months ago

Worth noting that this workaround produces different implicit injection semantics than we normally have: implicit injection normally only works in dots (also called dynamic dots). The implicit injection created by capture_args2() also applies to named arguments, making it closer to explicit injection (as with exprs() or inject()).
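
To make the dots-only rule concrete, here is a minimal sketch. It assumes rlang is installed; takes_dots() is a hypothetical helper, not part of rlang or keras.

```r
# takes_dots() is a hypothetical helper: `x` is an ordinary named formal,
# and only `...` is collected with dynamic dots via rlang::list2().
takes_dots <- function(x = NULL, ...) {
  rlang::list2(...)
}

extra <- list(a = 1)

# With `x` supplied by name, `!!!extra` lands in `...` and is spliced.
spliced <- takes_dots(x = 0, !!!extra)

# Without it, `!!!extra` is matched positionally to `x`, so nothing reaches
# `...`. `x` is never forced, so the unevaluated expression raises no error.
not_spliced <- takes_dots(!!!extra)

str(spliced)
length(not_spliced)
```

The second call is the same situation as the original layer_dense() example: the injection operator sits in a named formal's slot, where dynamic dots never see it.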

That said, we do have a long-standing plan of supporting injection in named arguments, e.g. for aes(x = , y = , ...), or for rlang::env(env = , ...). I'm not sure how call_match() fits into this picture. The main problem I see with call_match() is dots expansion.

See how dots evaluation is problematic with a call matching approach:

f <- function(...) {
  foo <- 1
  g(..., foo)
}
g <- function(...) {
  bar <- 2
  layer_dense(..., bar)
}

local({
  qux <- 0
  f(qux)
})
#> Error in eval(cl, envir = parent.frame(2L)) : object 'qux' not found

We could introduce quosures or quosure-like expressions in the matched call but then you get issues with labelling NSE.
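
As a rough illustration of why quosures help with the environment problem, the sketch below (assuming rlang is installed; make_quo() and local_val are hypothetical names) shows that a quosure remembers where it was created, while the bare expression extracted from it does not.

```r
library(rlang)

# make_quo() is a hypothetical helper: it creates a quosure that refers to
# a variable visible only inside the function's own environment.
make_quo <- function() {
  local_val <- 10
  quo(local_val + 1)
}

q <- make_quo()

# The quosure carries its environment, so evaluation finds `local_val`.
with_env <- eval_tidy(q)

# The bare expression has lost that environment; evaluating it elsewhere
# (here in baseenv()) fails because `local_val` cannot be found.
without_env <- tryCatch(
  eval(quo_get_expr(q), baseenv()),
  error = function(e) e
)

with_env
inherits(without_env, "error")
```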

t-kalinowski commented 9 months ago

Thank you, @lionel-!

Swapping exprs() and eval() for quos() and eval_tidy() seems to allow each argument to be resolved in its own environment, and also allows splicing with !!! from any of the calling environments:

capture_args2 <- function(ignore = NULL, 
                          envir = parent.frame(), 
                          fn = sys.function(-1)) {
  cl0 <- cl <- sys.call(-1)

  # first defuse rlang !!! and := in calls
  cl[[1L]] <- quote(rlang::quos)
  cl_exprs <- eval(cl, parent.frame(2L))

  # build up a call to base::list() using the quosures
  cl <- as.call(c(list, cl_exprs))

  # match.call()
  cl <- match.call(fn, cl,
                   expand.dots = !"..." %in% ignore,
                   envir = parent.frame(2L))

  # filter out args to ignore
  for(ig in intersect(names(cl), ignore))
    cl[[ig]] <- NULL

  # eval and capture args
  rlang::eval_tidy(cl, env = parent.frame(2L))

}

initializers <- list(kernel_initializer = "ones", 
                     bias_initializer = "ones")

layer_dense <- function(object, units, activation = NULL, use_bias = TRUE,
                        kernel_initializer = "glorot_uniform", bias_initializer = "zeros", 
                        kernel_regularizer = NULL, bias_regularizer = NULL, 
                        activity_regularizer = NULL,
                        kernel_constraint = NULL, bias_constraint = NULL, 
                        ..., mask = NULL) {

  args2 <- capture_args2()
  args2
}

f <- function(...) {
  foo <- 1
  g(..., foo)
}
g <- function(...) {
  bar <- 2
  layer_dense(..., bar)
}

local({
  qux <- 0
  str(f(qux))
  str(f(qux, !!!initializers))
})
#> List of 3
#>  $ object    : num 0
#>  $ units     : num 1
#>  $ activation: num 2
#> List of 5
#>  $ object            : num 0
#>  $ units             : num 1
#>  $ activation        : num 2
#>  $ kernel_initializer: chr "ones"
#>  $ bias_initializer  : chr "ones"

# !!! injection in f
f <- function(...) {
  foo <- 1
  g(..., foo, !!!initializers)
}
g <- function(...) {
  bar <- 2
  layer_dense(..., bar)
}
local({
  qux <- 0
  str(f(qux))
})
#> List of 5
#>  $ object            : num 0
#>  $ units             : num 1
#>  $ activation        : num 2
#>  $ kernel_initializer: chr "ones"
#>  $ bias_initializer  : chr "ones"

# !!! injection in g
f <- function(...) {
  foo <- 1
  g(..., foo)
}
g <- function(...) {
  bar <- 2
  layer_dense(..., bar, !!!initializers)
}
local({
  qux <- 0
  str(f(qux))
})
#> List of 5
#>  $ object            : num 0
#>  $ units             : num 1
#>  $ activation        : num 2
#>  $ kernel_initializer: chr "ones"
#>  $ bias_initializer  : chr "ones"

Created on 2023-10-04 with reprex v2.0.2

lionel- commented 9 months ago

Cool!