tidyverse / funs

Collection of low-level functions for working with vctrs
Other
34 stars 7 forks source link

recode_when() #66

Open romainfrancois opened 3 years ago

romainfrancois commented 3 years ago

closes #60

Using the magic `default ~` clause for now:

library(magrittr)
library(funs)

alphabet <- c(letters[1:10], NA)
alphabet %>% 
  recode_when(
    c("a", "e", "i", "o", "u") ~ "vowel", 
    NA                         ~ "missing", 
    default                    ~ "consonent"
  )
#>  [1] "vowel"     "consonent" "consonent" "consonent" "vowel"     "consonent"
#>  [7] "consonent" "consonent" "vowel"     "consonent" "missing"

Created on 2021-05-06 by the reprex package (v2.0.0)

romainfrancois commented 3 years ago

Wondering about the relationship with na_if() which could be:

# Replace the elements of `x` that match any of the values in `y` with NA.
# A single recode_when() clause does all the work; judging by the patch()
# reprex below, elements that match nothing come back unchanged.
na_if <- function(x, y) recode_when(x, y ~ NA)

but then perhaps `y` could be allowed to be something other than plain values?

romainfrancois commented 3 years ago

Trying to separate the recode() and the when() :

library(magrittr)
library(funs)

patch <- recode_when

1:10 %>% 
  patch(
    c(1, 2, 3) ~ 3, 
    when(~.x > 7) ~ 7
  )
#>  [1] 3 3 3 4 5 6 7 7 7 7

Created on 2021-05-06 by the reprex package (v2.0.0)

which perhaps eliminates the need for na_if()?

romainfrancois commented 3 years ago

Something like when() is only useful for allowing both values and predicates on the lhs. Maybe this is the same as where().

romainfrancois commented 3 years ago
library(dplyr, warn.conflicts = FALSE)
library(funs, warn.conflicts = FALSE)

# patch a single thing
band_instruments %>%
  mutate(
    name = patch(name, 
      when(plays == "guitar", paste0(name, "!")), 
      when(plays == "bass", paste0(name, "@"))
    )
  )
#> # A tibble: 3 x 2
#>   name   plays 
#>   <chr>  <chr> 
#> 1 John!  guitar
#> 2 Paul@  bass  
#> 3 Keith! guitar
romainfrancois commented 3 years ago

when() could also work with a case() function, to get a "case when":

library(funs)

x <- 1:50
case(
  when(x %% 35 == 0, "fizz buzz"), 
  when(x %% 5 == 0 , "fizz"), 
  when(x %% 7 == 0 , "buzz"), 

  when(default     , as.character(x))
)
#>  [1] "1"         "2"         "3"         "4"         "fizz"      "6"        
#>  [7] "buzz"      "8"         "9"         "fizz"      "11"        "12"       
#> [13] "13"        "buzz"      "fizz"      "16"        "17"        "18"       
#> [19] "19"        "fizz"      "buzz"      "22"        "23"        "24"       
#> [25] "fizz"      "26"        "27"        "buzz"      "29"        "fizz"     
#> [31] "31"        "32"        "33"        "34"        "fizz buzz" "36"       
#> [37] "37"        "38"        "39"        "fizz"      "41"        "buzz"     
#> [43] "43"        "44"        "fizz"      "46"        "47"        "48"       
#> [49] "buzz"      "fizz"

Created on 2021-05-19 by the reprex package (v2.0.0)

romainfrancois commented 3 years ago

I guess the whole patch() + cur_data() or patch() + across() combination is heavy; the current data could be supplied implicitly instead.

library(dplyr, warn.conflicts = FALSE)
library(funs, warn.conflicts = FALSE)
# Shorthand for `patch(cur_data(), ...)` inside mutate(): the current data
# is supplied implicitly.
#
# The when() clauses in `...` are captured unevaluated and spliced, as-is,
# into a call to patch() whose first argument is cur_data(). That call is
# then evaluated in the environment pick() was called from. In the mutate()
# reprex below, that is what lets bare column names such as `x` inside the
# clauses resolve (presumably via mutate()'s data mask; confirm if reused
# elsewhere).
#
# NOTE(review): newer dplyr versions export their own pick() and soft-
# deprecate cur_data(); this helper would mask the former. Confirm
# against the dplyr version in use.
pick <- function(...) {
  clauses <- enexprs(...)
  patch_call <- rlang::call2(patch, cur_data(), !!!clauses)
  eval(patch_call, envir = rlang::caller_env())
}

d <- tibble(x = 1:4, y = 1:4)
d %>% 
  mutate(
    pick(
      when(x < 2 , x = -x, y = -x), 
      when(x > 3 , x = 0)
    )
  )
#> # A tibble: 4 x 2
#>       x     y
#>   <int> <int>
#> 1    -1    -1
#> 2     2     2
#> 3     3     3
#> 4     0     4

Created on 2021-05-20 by the reprex package (v2.0.0)

romainfrancois commented 3 years ago

Might relate to mutate_when() https://github.com/tidyverse/dplyr/issues/4050

library(dplyr, warn.conflicts = FALSE)
library(funs, warn.conflicts = FALSE)

d <- tibble(x = 1:4, y = 1:4)
d %>% 
  mutate(
    patch(cur_data(),
      when(x < 2 , data.frame(x = -x, y = -x))
    )
  )
#> # A tibble: 4 x 2
#>       x     y
#>   <int> <int>
#> 1    -1    -1
#> 2     2     2
#> 3     3     3
#> 4     4     4

Created on 2021-05-20 by the reprex package (v2.0.0)

romainfrancois commented 3 years ago

Or maybe a dplyr::amend() sibling to mutate() that would allow:

d %>%
  amend(
    when( <some expression that select rows>, <things to do within these rows>), 
    when(x > 2, y = 3)
  )

i.e.

d %>% 
  amend(
     when(x > 1, z = 2), 
     when(x > 2, z = 3)
  )

would do (in spirit):

d$z[d$x > 1] <- 2
d$z[d$x > 2] <- 3
DavisVaughan commented 3 years ago

For dplyr::amend(), it might be useful to limit it to 1 row selection predicate that can update multiple columns. Then you don't need when(). Framing it as a mutate() gives you a lot of benefits (auto group handling, for one thing).

library(dplyr)
library(rlang)
library(vctrs)

# Amend selected rows of existing columns, mutate()-style.
#
# `.p` selects the rows to change. It is evaluated inside mutate(), so it
# runs per group on grouped data, and it may be a logical vector or row
# positions (whatever vctrs::vec_assign() accepts as locations). Each named
# argument in `...` supplies the new value of the column with that name at
# those rows; all other rows keep their current value. A new value may be:
#   * size 1, which is recycled across the selected rows, or
#   * as long as the column/group (e.g. `mpg = mpg + 1`), in which case it
#     is sliced down to the selected rows before being assigned.
# Columns are amended in order, as in mutate(), so a later expression sees
# earlier amendments. The predicate itself is computed once, before any
# column changes.
#
# Errors if any element of `...` is unnamed, or names a column that is not
# already in `.data`. The predicate is kept in a temporary column named
# `dplyr:::i`, which is dropped before returning.
amend <- function(.data, .p, ...) {
  .p <- enquo(.p)
  quos <- enquos(...)
  cols <- names(quos)

  if (!is_named2(quos)) {
    abort("All elements of `...` must be named.")
  }
  if (!all(cols %in% names(.data))) {
    abort("Can't amend columns that don't exist.")
  }

  # vec_assign() requires `value` to be size 1 or exactly as long as the
  # number of locations. A full-length value is therefore sliced with the
  # same locations first, so it lines up with the rows being replaced.
  # When every row is selected, both readings of the value agree.
  assign_at <- function(x, i, value, x_arg) {
    if (vec_size(value) == vec_size(x)) {
      value <- vec_slice(value, i)
    }
    vec_assign(x, i, value, x_arg = x_arg)
  }

  # One `col = assign_at(col, <predicate>, <new value>)` expression per
  # column. The injected calls are evaluated with this function's frame as
  # their environment, which is how `assign_at` is found inside mutate().
  exprs <- lapply(seq_along(quos), function(k) {
    col <- cols[[k]]
    expr(assign_at(
      !!sym(col), `dplyr:::i`, !!quos[[k]], x_arg = !!col
    ))
  })
  names(exprs) <- cols

  out <- mutate(.data, `dplyr:::i` = !!.p, !!!exprs)
  select(out, -`dplyr:::i`)
}

mtcars <- as_tibble(mtcars)

mtcars %>% 
  amend(mpg > mean(mpg), mpg = mean(mpg), cyl = -1)
#> # A tibble: 32 x 11
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  20.1    -1  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  20.1    -1  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  20.1    -1  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  20.1    -1  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  20.1    -1  147.    62  3.69  3.19  20       1     0     4     2
#>  9  20.1    -1  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # … with 22 more rows

# works with per group predicates and expressions
mtcars %>%
  group_by(cyl) %>%
  amend(mpg > mean(mpg), mpg = mean(mpg))
#> # A tibble: 32 x 11
#> # Groups:   cyl [3]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  19.7     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  19.7     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  19.7     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  15.1     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # … with 22 more rows

# Can't amend without named dots
mtcars %>% amend(mpg > mean(mpg), -1)
#> Error: All elements of `...` must be named.

# Can't amend columns that don't exist
mtcars %>% amend(mpg > mean(mpg), foo = -1)
#> Error: Can't amend columns that don't exist.

# Can technically amend by row position (per group)?
mtcars %>% group_by(cyl) %>% amend(1, mpg = -1)
#> # A tibble: 32 x 11
#> # Groups:   cyl [3]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  -1       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  -1       4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  -1       8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # … with 22 more rows

Created on 2021-05-26 by the reprex package (v1.0.0)

DavisVaughan commented 3 years ago

Ah, but vec_assign() is probably not the right tool here, since the size of the new value must match the number of locations being updated (or be 1).

mtcars %>% 
  amend(mpg > mean(mpg), mpg = mpg + 1)
#> Error: Problem with `mutate()` column `mpg`.
#> ℹ `mpg = vec_assign(mpg, `dplyr:::i`, mpg + 1, x_arg = "mpg")`.
#> x Can't recycle input of size 32 to size 14.

So maybe if_else() rather than vec_assign()? That would also restrict it to a logical .p, which feels right.

library(dplyr)
library(rlang)
library(vctrs)

# Amend selected rows of existing columns, mutate()-style.
#
# `.p` must evaluate to a logical vector. It is evaluated inside mutate(),
# so it runs per group on grouped data. Each named argument in `...`
# supplies the new value of the column with that name where `.p` is TRUE.
# Rows where `.p` is FALSE or NA keep their current value: NA means "not
# selected", matching base R's `x[cond] <- value`. Without this, NA rows
# would be silently overwritten with NA. New values follow if_else()
# rules, i.e. size 1 or as long as the group.
#
# Columns are amended in order, as in mutate(), so a later expression sees
# earlier amendments. The predicate itself is computed once, before any
# column changes.
#
# Errors if any element of `...` is unnamed, or names a column that is not
# already in `.data`. The predicate is kept in a temporary column named
# `dplyr:::i`, which is dropped before returning.
amend <- function(.data, .p, ...) {
  .p <- enquo(.p)
  quos <- enquos(...)
  cols <- names(quos)

  if (!is_named2(quos)) {
    abort("All elements of `...` must be named.")
  }
  if (!all(cols %in% names(.data))) {
    abort("Can't amend columns that don't exist.")
  }

  # One `col = if_else(<predicate>, <new value>, col, missing = col)`
  # expression per column; `missing = col` keeps existing values wherever
  # the predicate is NA.
  exprs <- lapply(seq_along(quos), function(k) {
    col <- sym(cols[[k]])
    expr(if_else(`dplyr:::i`, !!quos[[k]], !!col, missing = !!col))
  })
  names(exprs) <- cols

  out <- mutate(.data, `dplyr:::i` = !!.p, !!!exprs)
  select(out, -`dplyr:::i`)
}

mtcars <- as_tibble(mtcars)

# works with expressions where the length of the expression result equals
# the number of rows in the input
mtcars %>% 
  amend(mpg > mean(mpg), mpg = mpg + 1)
#> # A tibble: 32 x 11
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  22       6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  22       6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  23.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  22.4     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  25.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  23.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # … with 22 more rows

# works with per group predicates where the rhs of the expression is the length
# of the group
mtcars %>% 
  group_by(cyl) %>%
  amend(mpg > mean(mpg), mpg = mpg + mean(mpg))
#> # A tibble: 32 x 11
#> # Groups:   cyl [3]
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  40.7     6  160    110  3.9   2.62  16.5     0     1     4     4
#>  2  40.7     6  160    110  3.9   2.88  17.0     0     1     4     4
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1
#>  4  41.1     6  258    110  3.08  3.22  19.4     1     0     3     1
#>  5  33.8     8  360    175  3.15  3.44  17.0     0     0     3     2
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4
#> # … with 22 more rows

# Can't amend by row position: the predicate must be logical for if_else()
mtcars %>% group_by(cyl) %>% amend(1, mpg = mpg)
#> Error: Problem with `mutate()` column `mpg`.
#> ℹ `mpg = if_else(`dplyr:::i`, mpg, mpg)`.
#> x `condition` must be a logical vector, not a double vector.
#> ℹ The error occurred in group 1: cyl = 4.
romainfrancois commented 3 years ago

Chances are you would want multiple amendments, e.g. `mpg = mpg + 1` for "big values" and `mpg = mpg - 1` for "small values":

mtcars %>% 
  amend(mpg > mean(mpg), mpg = mpg + 1) %>%
  amend(mpg < mean(mpg), mpg = mpg - 1)

The `mpg < mean(mpg)` and `mpg - 1` of the second amend() would then be based on the already-updated mpg, which has a greater mean.

Also, in terms of readability, I feel that when() adds something:

mtcars %>% 
  amend(
    when(mpg > mean(mpg), mpg = mpg + 1), 
    when(mpg < mean(mpg), mpg = mpg - 1)
  )

Also, I think this is somewhat orthogonal to dplyr's usual grouping with mutate(), where the grouping is specified upstream and the same code is applied to each group.

Here, the grouping is specified inline, and a different thing is done to each group. I also believe that, with this, the results of each `...` in when() should be tidy-recycled to the size implied by `mpg > mean(mpg)`, i.e. `mpg = mpg + 1` is evaluated in a data mask where `mpg` stands for `mpg[mpg > mean(mpg)]`.