mrc-ide / odin

ᚩ A DSL for describing and solving differential equations in R
https://mrc-ide.github.io/odin

Implement min / max for arrays #293

Open thibautjombart opened 1 year ago

thibautjombart commented 1 year ago

Context

I am trying to retrieve the largest value of an array in odin. It seems that min/max accept comma-separated scalar arguments but do not operate on arrays. I feel I must be missing something obvious... any tips welcome.

Reprex

This is not my actual use case, but a minimal example that generates a vector of Poisson-distributed values and reports either the sum or the maximum value at each iteration.

library(odin)

# The following, using `sum`, works as intended:

toy_expl <- "
  n_patches <- user()
  dim(x) <- n_patches
  initial(x[]) <- 0
  update(x[]) <- rpois(n_patches)
  y <- sum(x[])
  output(y) <- TRUE
"

set.seed(1)
odin(toy_expl)$new(n_patches = 3)$run(1:6)
#>      step x[1] x[2] x[3]  y
#> [1,]    1    0    0    0  0
#> [2,]    2    1    4    2  7
#> [3,]    3    5    3    4 12
#> [4,]    4    3    3    4 10
#> [5,]    5    0    3    4  7
#> [6,]    6    4    3    5 12

# This generates an error, which I think is intended, as 'max' expects several
# arguments and is not designed for arrays:

toy_expl <- "
  n_patches <- user()
  dim(x) <- n_patches
  initial(x[]) <- 0
  update(x[]) <- rpois(n_patches)
  y <- max(x[])
  output(y) <- TRUE
"

set.seed(1)
odin(toy_expl)$new(n_patches = 3)$run(1:6)
#> Error: Expected 2 or more arguments in max call, but recieved 1
#>    y <- max(x[]) # (line 6)

# Trying a workaround to have 'y' return the maximum of all values in x[], but
# this fails because 'y' references itself:

toy_expl <- "
  n_patches <- user()
  dim(x) <- n_patches
  initial(x[]) <- 0
  update(x[]) <- rpois(n_patches)
  y <- 0 # reasonable default here
  y <- if (x[i] > y) x[i] else y
  output(y) <- TRUE
"

set.seed(1)
odin(toy_expl)$new(n_patches = 3)$run(1:6)
#> Error: Self referencing expressions not allowed (except for arrays)
#>    y <- if (x[i] > y) x[i] else y # (line 7)

Created on 2023-05-10 with reprex v2.0.2

richfitz commented 1 year ago

Sorry, missed this. This is not terrible to do tbh. I doubt I'll get it done in a short period of time, but it's certainly something that won't break anything.

In the meantime, I believe you can do this in odin.dust with custom C++ code: https://github.com/mrc-ide/odin.dust/blob/master/tests/testthat/test-odin-dust.R#L424-L439 and https://github.com/mrc-ide/odin.dust/blob/master/tests/testthat/include.cpp. I'm not sure offhand whether OG odin allows passing a vector in like this. If you need it to sweep across array dimensions, though, that will be unsatisfying.
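A possible pure-odin workaround, offered as an untested sketch rather than a documented feature: the self-referencing error in the reprex says such expressions are allowed for arrays, and the `max` error says the two-argument form is accepted. Combining the two, a running maximum can be accumulated in a helper array. The names `toy_max` and `x_max` are hypothetical, and the pattern assumes odin accepts an array equation of the form `a[2:n] <- f(a[i - 1], ...)`.

```r
library(odin)

# Assumption: odin permits self-referencing *array* equations, so a running
# maximum can be built element by element using two-argument max().
toy_max <- "
  n_patches <- user()
  dim(x) <- n_patches
  initial(x[]) <- 0
  update(x[]) <- rpois(n_patches)

  # Hypothetical helper: x_max[i] holds max(x[1], ..., x[i])
  dim(x_max) <- n_patches
  x_max[1] <- x[1]
  x_max[2:n_patches] <- max(x_max[i - 1], x[i])

  # The last element is therefore the maximum over the whole array
  y <- x_max[n_patches]
  output(y) <- TRUE
"

set.seed(1)
res <- odin(toy_max)$new(n_patches = 3)$run(1:6)
res
```

If this compiles, each row's `y` should match the largest of `x[1]`, `x[2]` and `x[3]` in that row. It avoids custom C or C++ code, at the cost of an extra array of length `n_patches`.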