mrc-ide / odin

A DSL for describing and solving differential equations in R
https://mrc-ide.github.io/odin

Support Matrix Multiplication #291

Open CGMossa opened 1 year ago

CGMossa commented 1 year ago

Benchmarking R + deSolve code against the equivalent odin code yielded a surprising result:

# A tibble: 2 × 13
  expression      min   median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory     time      
  <bch:expr> <bch:tm> <bch:tm>     <dbl> <bch:byt>    <dbl> <int> <dbl>   <bch:tm> <list> <list>     <list>    
1 odin_c        303ms    305ms      3.28    8.59MB      0       2     0      610ms <NULL> <Rprofmem> <bench_tm>
2 deSolve_r     121ms    121ms      8.23   38.04MB     24.7     1     3      121ms <NULL> <Rprofmem> <bench_tm>
# ℹ 1 more variable: gc <list>

I suspect the culprit is the lack of matrix multiplication in odin (or maybe I don't know how to invoke it). In the deSolve part:

between_sites <- transmission_rate * S * (foi_matrix %*% I)

However, in the odin part I do:

foi[, ] <- foi_mat[i, j] * I[j]
delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i, ])
deriv(S[]) <- -delta_transmission[i]
deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]

Here I've omitted the parts I don't think are necessary.
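
For reference, the two formulations compute the same quantity: with foi[i, j] = foi_mat[i, j] * I[j], the row sum sum(foi[i, ]) is element i of foi_mat %*% I. A minimal sketch in plain C of both forms follows. It assumes column-major storage as in R, the function names are hypothetical, and it is not odin's generated code.

```c
#include <stddef.h>

/* Two-step form, mirroring the odin equations: build the full n x n
 * intermediate foi[i, j] = foi_mat[i, j] * I[j], then sum each row.
 * Matrices are column-major (element [i, j] lives at i + j * n), as in R. */
void row_sums_two_step(size_t n, const double *foi_mat, const double *I,
                       double *foi, double *out) {
  for (size_t j = 0; j < n; ++j) {
    for (size_t i = 0; i < n; ++i) {
      foi[i + j * n] = foi_mat[i + j * n] * I[j];
    }
  }
  for (size_t i = 0; i < n; ++i) {
    double total = 0.0;
    for (size_t j = 0; j < n; ++j) {
      total += foi[i + j * n];
    }
    out[i] = total;
  }
}

/* Fused form: the matrix-vector product foi_mat %*% I, accumulated
 * column by column without storing the intermediate matrix. */
void mat_vec(size_t n, const double *foi_mat, const double *I, double *out) {
  for (size_t i = 0; i < n; ++i) {
    out[i] = 0.0;
  }
  for (size_t j = 0; j < n; ++j) {
    for (size_t i = 0; i < n; ++i) {
      out[i] += foi_mat[i + j * n] * I[j];
    }
  }
}
```

Both functions do the same number of multiplications. The two-step form additionally writes and re-reads an n x n intermediate at every derivative evaluation, which may account for part of the gap, but that would need to be confirmed by profiling.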

I'll work on a minimal testcase to check if this is indeed the problem.

Under details, I have more complete excerpts of my code:

Details

```r
odin::odin({
  foi[, ] <- foi_mat[i, j] * I[j]
  delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i,])
  delta_recovery[] <- recovery_rate * I[i]

  deriv(S[]) <- -delta_transmission[i]
  deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]
  deriv(R[]) <- delta_recovery[i]

  source_id <- user()
  target_id <- user()
  Total[] <- S[i] + I[i] + R[i]
  output(source_prevalence) <- I[as.integer(source_id)] / Total[as.integer(source_id)]
  output(target_prevalence) <- I[as.integer(target_id)] / Total[as.integer(target_id)]

  transmission_rate <- user(0.05)
  recovery_rate <- user(0.01)
  S0[] <- user()
  I0[] <- user()
  foi_mat[,] <- user()

  initial(S[]) <- S0[i]
  initial(I[]) <- I0[i]
  initial(R[]) <- Total[i] - S0[i] - I0[i]

  dim(S0) <- user()
  dim(S) <- N
  dim(I) <- N
  dim(R) <- N
  dim(I0) <- N
  dim(foi_mat) <- c(N, N)
  dim(foi) <- c(N, N)
  dim(Total) <- N
  dim(delta_transmission) <- N
  dim(delta_recovery) <- N
  N <- length(S0)
}, verbose = TRUE, validate = TRUE, target = "c", pretty = TRUE, skip_cache = FALSE) -> model_generator

model <- model_generator$new(S0 = site_S, I0 = site_I,
                             foi_mat = foi_matrix,
                             source_id = as.integer(source_site_id),
                             target_id = as.integer(target_site_id))
model$set_user(transmission_rate = 0.05, recovery_rate = 0.01)
```

Benchmarking:

```r
bench::mark(
  odin_c = model$run(0:100),
  deSolve_r = {
    site_R <- site_S
    site_R[] <- 0
    deSolve::ode(
      y = c(S = site_S, I = site_I, R = site_R),
      times = 0:100,
      func = function(time, state, parms) {
        with(parms, {
          S <- state[1:N]
          I <- state[(N + 1):(2 * N)]
          R <- state[(2 * N + 1):(3 * N)]
          Total <- S + I + R
          between_sites <- transmission_rate * S * (foi_matrix %*% I)
          source_prevalence <- I[[source_id]] / Total[[source_id]]
          target_prevalence <- I[[target_id]] / Total[[target_id]]
          list(c(
            dS = -transmission_rate * S * I - between_sites,
            dI = +transmission_rate * S * I + between_sites - recovery_rate * I,
            dR = recovery_rate * I
          ),
          source_prevalence = source_prevalence,
          target_prevalence = target_prevalence)
        })
      },
      parms = list(
        transmission_rate = 0.05,
        recovery_rate = 0.01,
        source_id = as.integer(source_site_id),
        target_id = as.integer(target_site_id),
        N = length(site_S)
      )
    )
  },
  check = FALSE
) %>%
  print()
```

richfitz commented 1 year ago

unfortunately this is not that surprising - if the model is dominated by a matrix multiplication, then the version that uses a linear algebra library will be much faster.

Supporting this properly has been on the back burner for a long time (#38, #134, #213 - these mostly concern multinomial distributions but the syntactic issue in #134 is shared and is the primary blocker). The actual calling convention is not that bad, though it does mean that users need a working copy of gfortran to compile models, which is quite annoying in practice, particularly for people on Macs.

CGMossa commented 1 year ago

I'm glad you agree. For my use-case, I can circumvent this by being a little more clever about this. But to stick to this issue, and since you know this stuff already:

C:\Users\minin>R CMD config LAPACK_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRlapack

But what if I just think about BLAS (whatever that is)? First, it says:

 R packages that use these should have PKG_LIBS in src/Makevars include
   $(BLAS_LIBS) $(FLIBS)

So on my Windows machine it is

C:\Users\minin>R CMD config FLIBS
-lgfortran -lm -lquadmath

C:\Users\minin>R CMD config BLAS_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRblas

Then apparently dgemm is the Fortran routine that is supposed to do this; I've copied the prototype/header:


/* DGEMM - perform one of the matrix-matrix operations    */
/* C := alpha*op( A )*op( B ) + beta*C */
BLAS_extern void
F77_NAME(dgemm)(const char *transa, const char *transb, const int *m,
        const int *n, const int *k, const double *alpha,
        const double *a, const int *lda,
        const double *b, const int *ldb,
        const double *beta, double *c, const int *ldc 
        FCLEN FCLEN);

Finally, I've asked ChatGPT about this and it suggested this code for invoking this:

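#define USE_FC_LEN_T  /* expose the hidden Fortran character-length arguments */
#include <R.h>
#include <Rinternals.h>
#include <R_ext/BLAS.h>
#ifndef FCONE
# define FCONE
#endif

SEXP matrix_mult(SEXP a, SEXP b) {
  /* REAL() is only valid on double storage, so check the inputs first. */
  if (!isReal(a) || !isReal(b) || !isMatrix(a) || !isMatrix(b)) {
    error("Both arguments must be numeric (double) matrices.");
  }
  int nrow_a = nrows(a);
  int ncol_a = ncols(a);
  int nrow_b = nrows(b);
  int ncol_b = ncols(b);

  /* error() does not return, so no explicit return is needed after it. */
  if (ncol_a != nrow_b) {
    error("Matrix dimensions do not match for multiplication.");
  }

  SEXP result = PROTECT(allocMatrix(REALSXP, nrow_a, ncol_b));

  /* result := alpha * a %*% b + beta * result; "N" means "do not transpose". */
  double alpha = 1.0;
  double beta = 0.0;
  F77_CALL(dgemm)("N", "N", &nrow_a, &ncol_b, &ncol_a, &alpha, REAL(a), &nrow_a,
                  REAL(b), &nrow_b, &beta, REAL(result), &nrow_a FCONE FCONE);

  UNPROTECT(1);
  return result;
}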

I don't know where these "N"s come from. There is more than one of these routines, though, and this one is specifically matrix-matrix (while I apparently need matrix-vector). Presumably it is those SEXPTYPEs that are the differentiator.
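
On the "N" question: in the BLAS interface the leading character arguments select whether each matrix is used as-is ("N", no transpose) or transposed ("T"). The matrix-vector counterpart of dgemm is dgemv, which computes y := alpha*op(A)*x + beta*y. The sketch below is a plain-C illustration of those semantics only. The name ref_dgemv is hypothetical, it is not the library routine, and it omits the incx/incy stride arguments of the real interface.

```c
#include <stddef.h>

/* Plain-C illustration of what BLAS dgemv computes (hypothetical helper,
 * not the library routine):
 *   trans == 'N':  y := alpha * A   * x + beta * y   (A is m x n)
 *   trans == 'T':  y := alpha * A^T * x + beta * y
 * A is column-major with leading dimension lda, as in R and Fortran. */
void ref_dgemv(char trans, int m, int n, double alpha, const double *a,
               int lda, const double *x, double beta, double *y) {
  int len_y = (trans == 'N') ? m : n;

  /* Scale the existing y; with beta == 0 the old contents are ignored. */
  for (int i = 0; i < len_y; ++i) {
    y[i] = (beta == 0.0) ? 0.0 : beta * y[i];
  }

  /* Accumulate alpha * op(A) * x, walking A column by column. */
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      double aij = a[(size_t)i + (size_t)j * (size_t)lda];
      if (trans == 'N') {
        y[i] += alpha * aij * x[j];
      } else {
        y[j] += alpha * aij * x[i];
      }
    }
  }
}
```

From C code compiled against R, the corresponding call would presumably go through F77_CALL(dgemv) from R_ext/BLAS.h, passing every scalar by pointer and using strides of 1, in the same style as the dgemm call above.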

I've googled and BLAS should be supported on Mac. I don't know how that relates to LAPACK, or where they are switched or changed.

Details

```
Usage: R CMD config [options] [VAR]

Get the value of a basic R configure variable VAR which must be among
those listed in the 'Variables' section below, or the header and
library flags necessary for linking a front-end against R.

Options:
  -h, --help          print short help message and exit
  -v, --version       print version info and exit
      --cppflags      print pre-processor flags required to compile
                      a C/C++ file as part of a front-end using R as a library
      --ldflags       print linker flags needed for linking a front-end
                      against the R library
      --no-user-files ignore customization files under ~/.R
      --no-site-files ignore site customization files under R_HOME/etc
      --all           print names and values of all variables below

Variables:
  AR                command to make static libraries
  BLAS_LIBS         flags needed for linking against external BLAS libraries
  CC                C compiler command
  CFLAGS            C compiler flags
  CC17              Ditto for the C17 or earlier compiler
  C17FLAGS
  CC23              Ditto for the C23 or later compiler
  C23FLAGS
  CPICFLAGS         special flags for compiling C code to be included in a
                    shared library
  CPPFLAGS          C/C++ preprocessor flags, e.g. -I if you have
                    headers in a nonstandard directory
  CXX               default compiler command for C++ code
  CXXFLAGS          compiler flags for CXX
  CXXPICFLAGS       special flags for compiling C++ code to be included in a
                    shared library
  CXX11             compiler command for C++11 code
  CXX11STD          flag used with CXX11 to enable C++11 support
  CXX11FLAGS        further compiler flags for CXX11
  CXX11PICFLAGS     special flags for compiling C++11 code to be included in
                    a shared library
  CXX14             compiler command for C++14 code
  CXX14STD          flag used with CXX14 to enable C++14 support
  CXX14FLAGS        further compiler flags for CXX14
  CXX14PICFLAGS     special flags for compiling C++14 code to be included in
                    a shared library
  CXX17             compiler command for C++17 code
  CXX17STD          flag used with CXX17 to enable C++17 support
  CXX17FLAGS        further compiler flags for CXX17
  CXX17PICFLAGS     special flags for compiling C++17 code to be included in
                    a shared library
  CXX20             compiler command for C++20 code
  CXX20STD          flag used with CXX20 to enable C++20 support
  CXX20FLAGS        further compiler flags for CXX20
  CXX23             compiler command for C++23 code
  CXX23STD          flag used with CXX23 to enable C++23 support
  CXX23FLAGS        further compiler flags for CXX23
  CXX23PICFLAGS     special flags for compiling C++23 code to be included in
                    a shared library
  DYLIB_EXT         file extension (including '.') for dynamic libraries
  DYLIB_LD          command for linking dynamic libraries which contain
                    object files from a C or Fortran compiler only
  DYLIB_LDFLAGS     special flags used by DYLIB_LD
  FC                Fortran compiler command
  FFLAGS            fixed-form Fortran compiler flags
  FCFLAGS           free-form Fortran 9x compiler flags
  FLIBS             linker flags needed to link Fortran code
  FPICFLAGS         special flags for compiling Fortran code to be turned
                    into a shared library
  JAR               Java archive tool command
  JAVA              Java interpreter command
  JAVAC             Java compiler command
  JAVAH             Java header and stub generator command
  JAVA_HOME         path to the home of Java distribution
  JAVA_LIBS         flags needed for linking against Java libraries
  JAVA_CPPFLAGS     C preprocessor flags needed for compiling JNI programs
  LAPACK_LIBS       flags needed for linking against external LAPACK libraries
  LIBnn             location for libraries, e.g. 'lib' or 'lib64' on this
                    platform
  LDFLAGS           linker flags, e.g. -L if you have libraries in a
                    nonstandard directory
  LTO LTO_FC LTO_LD flags for Link-Time Optimization
  MAKE              Make command
  NM                comand to display symbol tables
  OBJC              Objective C compiler command
  OBJCFLAGS         Objective C compiler flags
  RANLIB            command to index static libraries
  SAFE_FFLAGS       Safe (as conformant as possible) Fortran compiler flags
  SHLIB_CFLAGS      additional CFLAGS used when building shared objects
  SHLIB_CXXFLAGS    additional CXXFLAGS used when building shared objects
  SHLIB_CXXLD       command for linking shared objects which contain
                    object files from a C++ compiler (and CXX11 CXX14 CXX17
                    CXX20 CXX23)
  SHLIB_CXXLDFLAGS  special flags used by SHLIB_CXXLD (and CXX11 CXX14 CXX17
                    CXX20 CXX23)
  SHLIB_EXT         file extension (including '.') for shared objects
  SHLIB_FFLAGS      additional FFLAGS used when building shared objects
  SHLIB_LD          command for linking shared objects which contain
                    object files from a C or Fortran compiler only
  SHLIB_LDFLAGS     special flags used by SHLIB_LD
  TCLTK_CPPFLAGS    flags needed for finding the tcl.h and tk.h headers
  TCLTK_LIBS        flags needed for linking against the Tcl and Tk libraries

Windows only:
  COMPILED_BY       name and version of compiler used to build R
  LOCAL_SOFT        absolute path to '/usr/local' software collection
  R_TOOLS_SOFT      absolute path to 'R tools' software collection
  OBJDUMP           command to dump objects

Report bugs at .
```

CGMossa commented 1 year ago

Okay, I've also found this snippet here that might be helpful:

https://cran.r-project.org/bin/macosx/RMacOSX-FAQ.html#Which-BLAS-is-used-and-how-can-it-be-changed_003f

And of course this:

https://cran.r-project.org/doc/manuals/r-release/R-admin.html#Linear-algebra

richfitz commented 1 year ago

Thanks - that part is straightforward and we do it elsewhere (for example https://github.com/mrc-ide/eigen1/blob/master/src/util.c#L16-L17). The pain comes when users have not correctly installed the Fortran parts of the toolchain, and on Macs that changes every couple of years as Apple and R-core change how things get installed.

The blocker on this is the odin syntax, and that's been unresolved for about 5 years so I doubt we will get to it soon!

CGMossa commented 1 year ago

Good. I won't comment on the syntax just yet, especially since I don't know anything about parsers. I guess the problem is that right now the line order doesn't matter, but for the three-step definition it would need to? In any case, thanks for indulging this conversation.

I guess, for my personal understanding: on Windows we have Rblas.dll, and I had hoped it was possible to just link to that without needing a Fortran compiler. However, on Windows we also have Rtools, which most likely contains a Fortran compiler, so I don't really have experience with this. I would have guessed that -shared plus linking to Rblas.dll (or its equivalent elsewhere) would have been enough...

richfitz commented 1 year ago

Windows tends to be fine because R core controls the whole toolchain. On Mac, at linking, you get issues if libgfortran is not found.

Line order won't matter for this either - the intention is to support y <- A %*% x and convert that to the appropriate blas call based on what we know about y, A and x. The issue is when (inevitably) people want to apply these transformations to higher order objects, so looping over part of y at each operation, so we're thinking about things like:

y[., ] <- A[j, ., .] %*% x[., j]

at the moment