Open CGMossa opened 1 year ago
Unfortunately this is not that surprising - if the model is dominated by a matrix multiplication, then the version that uses a linear algebra library will be much faster.
Supporting this properly has been on the back burner for a long time (#38, #134, #213 - these mostly concern multinomial distributions, but the syntactic issue in #134 is shared and is the primary blocker). The actual calling convention is not that bad, though it does mean that users need a working copy of gfortran to compile models, which is quite annoying in practice, particularly for people on Macs.
I'm glad you agree. For my use-case, I can circumvent this by being a little more clever about this. But to stick to this issue, and since you know this stuff already:
`R CMD config` can be used to get the right flags for compiling on different platforms:
```
C:\Users\minin>R CMD config LAPACK_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRlapack
```
But if I just think about BLAS (whatever that is). First, it says:
> R packages that use these should have `PKG_LIBS` in `src/Makevars` include `$(BLAS_LIBS) $(FLIBS)`

So on my Windows machine it is
```
C:\Users\minin>R CMD config FLIBS
-lgfortran -lm -lquadmath
C:\Users\minin>R CMD config BLAS_LIBS
-LC:/Users/minin/scoop/apps/r/current/bin/x64 -lRblas
```
Then apparently `dgemm` is the Fortran routine that is supposed to do this. I've copied the prototype/header:
```c
/* DGEMM - perform one of the matrix-matrix operations    */
/* C := alpha*op( A )*op( B ) + beta*C */
BLAS_extern void
F77_NAME(dgemm)(const char *transa, const char *transb, const int *m,
                const int *n, const int *k, const double *alpha,
                const double *a, const int *lda,
                const double *b, const int *ldb,
                const double *beta, double *c, const int *ldc
                FCLEN FCLEN);
```
Finally, I've asked ChatGPT about this and it suggested this code for invoking this:
SEXP matrix_mult(SEXP a, SEXP b) {
SEXP result;
int nrow_a = nrows(a);
int ncol_a = ncols(a);
int nrow_b = nrows(b);
int ncol_b = ncols(b);
if (ncol_a != nrow_b) {
error("Matrix dimensions do not match for multiplication.");
return R_NilValue;
}
PROTECT(result = allocMatrix(REALSXP, nrow_a, ncol_b));
double alpha = 1.0;
double beta = 0.0;
F77_CALL(dgemm)("N", "N", &nrow_a, &ncol_b, &ncol_a, &alpha, REAL(a), &nrow_a,
REAL(b), &nrow_b, &beta, REAL(result), &nrow_a);
UNPROTECT(1);
return result;
}
I don't know where these `"N"` arguments come from.
But there is more than one of these routines, and this one is specifically matrix-matrix (while I apparently need matrix-vector).
Presumably it is those SEXPTYPEs that are the differentiator.
I've googled and BLAS should be supported on Mac. I don't know how that relates to LAPACK, or where they are switched or changed.
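On the `"N"` arguments and the matrix-vector case: in the `dgemm` prototype quoted above, the first two arguments are `transa` and `transb`, and `"N"` means "use the matrix as stored, no transpose" (`"T"` would request the transpose). The routine is chosen by its name, not by SEXPTYPE: the `d` prefix means double precision, and the matrix-vector counterpart of `dgemm` is `dgemv`. The sketch below is a plain-C reference version of what a `dgemv`-style call computes on column-major storage (the layout R uses). It is not the BLAS routine itself: `ref_gemv` is a hypothetical name, and the real `dgemv` additionally takes stride arguments (`incx`, `incy`) that are left out here.

```c
#include <assert.h>
#include <stddef.h>

/* Reference (non-BLAS) version of  y := alpha * op(A) * x + beta * y.
 * A is m-by-n, stored column-major with leading dimension lda.
 * trans = 'N': op(A) = A     (x has n entries, y has m entries)
 * trans = 'T': op(A) = t(A)  (x has m entries, y has n entries)
 * ref_gemv is a hypothetical helper for illustration only. */
void ref_gemv(char trans, int m, int n, double alpha,
              const double *a, int lda,
              const double *x, double beta, double *y)
{
    assert(trans == 'N' || trans == 'T');
    assert(lda >= m);

    int leny = (trans == 'N') ? m : n;
    int lenx = (trans == 'N') ? n : m;

    for (int i = 0; i < leny; ++i) {
        double acc = 0.0;
        for (int j = 0; j < lenx; ++j) {
            /* element (row, col) of A lives at a[row + col * lda] */
            double aij = (trans == 'N') ? a[i + (size_t)j * lda]
                                        : a[j + (size_t)i * lda];
            acc += aij * x[j];
        }
        /* with beta == 0 the old contents of y are ignored */
        y[i] = alpha * acc + (beta == 0.0 ? 0.0 : beta * y[i]);
    }
}
```

The leading dimension (`lda`) is the number of rows in the stored array, which is why the quoted `dgemm` call passes `&nrow_a` and `&nrow_b` there.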
```
Usage: R CMD config [options] [VAR]

Get the value of a basic R configure variable VAR which must be among
those listed in the 'Variables' section below, or the header and
library flags necessary for linking a front-end against R.

Options:
  -h, --help        print short help message and exit
  -v, --version     print version info and exit
  --cppflags        print pre-processor flags required to compile
                    a C/C++ file as part of a front-end using R as a library
  --ldflags         print linker flags needed for linking a front-end
                    against the R library
  --no-user-files   ignore customization files under ~/.R
  --no-site-files   ignore site customization files under R_HOME/etc
  --all             print names and values of all variables below

Variables:
  AR            command to make static libraries
  BLAS_LIBS     flags needed for linking against external BLAS libraries
  CC            C compiler command
  CFLAGS        C compiler flags
  CC17          Ditto for the C17 or earlier compiler
  C17FLAGS
  CC23          Ditto for the C23 or later compiler
  C23FLAGS
  CPICFLAGS     special flags for compiling C code to be included in a
                shared library
  CPPFLAGS      C/C++ preprocessor flags, e.g. -I
```
Okay, I've also found this snippet here that might be helpful:
And of course this:
https://cran.r-project.org/doc/manuals/r-release/R-admin.html#Linear-algebra
Thanks - that part is straightforward and we do it elsewhere (for example https://github.com/mrc-ide/eigen1/blob/master/src/util.c#L16-L17) - the pain comes when users have not correctly installed the Fortran parts of the toolchain - and on Macs that changes every couple of years as Apple and R-core change how things get installed.
The blocker on this is the odin syntax, and that's been unresolved for about 5 years so I doubt we will get to it soon!
Good. I won't comment on the syntax just yet, especially since I don't know anything about parsers. I guess the problem is that right now the line order doesn't matter, but for the three-step definition it would need to? In any case, thanks for indulging this conversation.
I guess, for my personal understanding, on Windows we have `Rblas.dll`, and I had hoped it was possible to just link to that and not need a Fortran compiler. On Windows, however, we have Rtools, which most likely also contains a Fortran compiler, so I don't really have experience with this. I would have guessed that `-shared` plus linking to `Rblas.dll` (or the equivalent elsewhere) would have been enough...
Windows tends to be fine because R core controls the whole toolchain. On Mac, at link time, you get issues if libgfortran is not found.
Line order won't matter for this either - the intention is to support `y <- A %*% x` and convert that to the appropriate BLAS call based on what we know about `y`, `A` and `x`. The issue is when (inevitably) people want to apply these transformations to higher-order objects, looping over part of `y` at each operation, so at the moment we're thinking about things like `y[., ] <- A[j, ., .] %*% x[., j]`.
Benchmarking R + deSolve code against the equivalent odin code yielded a surprising result:
I suspect the culprit is the lack of matrix multiplication in `odin` (or maybe I don't know how to invoke it). In the `deSolve` part, the between-site term uses a matrix-vector product, `foi_matrix %*% I`.
However, in the `odin` part I fill `foi[, ] <- foi_mat[i, j] * I[j]` and then take `sum(foi[i,])`.
Here I've omitted the parts I don't think are necessary.
I'll work on a minimal test case to check if this is indeed the problem.
Under details, I have more complete excerpts of my code:
Details
```r
odin::odin({
  foi[, ] <- foi_mat[i, j] * I[j]
  delta_transmission[] <- transmission_rate * S[i] * I[i] + transmission_rate * S[i] * sum(foi[i,])
  delta_recovery[] <- recovery_rate * I[i]

  deriv(S[]) <- -delta_transmission[i]
  deriv(I[]) <- +delta_transmission[i] - delta_recovery[i]
  deriv(R[]) <- delta_recovery[i]

  source_id <- user()
  target_id <- user()
  Total[] <- S[i] + I[i] + R[i]
  output(source_prevalence) <- I[as.integer(source_id)] / Total[as.integer(source_id)]
  output(target_prevalence) <- I[as.integer(target_id)] / Total[as.integer(target_id)]

  transmission_rate <- user(0.05)
  recovery_rate <- user(0.01)
  S0[] <- user()
  I0[] <- user()
  foi_mat[,] <- user()

  initial(S[]) <- S0[i]
  initial(I[]) <- I0[i]
  initial(R[]) <- Total[i] - S0[i] - I0[i]

  dim(S0) <- user()
  dim(S) <- N
  dim(I) <- N
  dim(R) <- N
  dim(I0) <- N
  dim(foi_mat) <- c(N, N)
  dim(foi) <- c(N, N)
  dim(Total) <- N
  dim(delta_transmission) <- N
  dim(delta_recovery) <- N
  N <- length(S0)
},
verbose = TRUE,
validate = TRUE,
target = "c",
pretty = TRUE,
skip_cache = FALSE) -> model_generator

model <- model_generator$new(
  S0 = site_S,
  I0 = site_I,
  foi_mat = foi_matrix,
  source_id = as.integer(source_site_id),
  target_id = as.integer(target_site_id)
)
model$set_user(transmission_rate = 0.05, recovery_rate = 0.01)
```

Benchmarking:

```r
bench::mark(
  odin_c = model$run(0:100),
  deSolve_r = {
    site_R <- site_S
    site_R[] <- 0
    deSolve::ode(
      y = c(S = site_S, I = site_I, R = site_R),
      times = 0:100,
      func = function(time, state, parms) {
        with(parms, {
          S <- state[1:N]
          I <- state[(N + 1):(2 * N)]
          R <- state[(2 * N + 1):(3 * N)]
          Total <- S + I + R
          between_sites <- transmission_rate * S * (foi_matrix %*% I)
          source_prevalence <- I[[source_id]] / Total[[source_id]]
          target_prevalence <- I[[target_id]] / Total[[target_id]]
          list(c(
            dS = -transmission_rate * S * I - between_sites,
            dI = +transmission_rate * S * I + between_sites - recovery_rate * I,
            dR = recovery_rate * I
          ),
          source_prevalence = source_prevalence,
          target_prevalence = target_prevalence)
        })
      },
      parms = list(
        transmission_rate = 0.05,
        recovery_rate = 0.01,
        source_id = as.integer(source_site_id),
        target_id = as.integer(target_site_id),
        N = length(site_S)
      )
    )
  },
  check = FALSE
) %>%
  print()
```
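To make the difference between the two formulations concrete, here is a rough C sketch of what each one computes. The first function mirrors the odin lines `foi[, ] <- foi_mat[i, j] * I[j]` followed by `sum(foi[i,])`: it fills an N-by-N temporary and then sums each row. The second computes the same row sums directly as a matrix-vector product, which is what `foi_matrix %*% I` does on the R side. This is only an illustration under the assumption of column-major storage. The function names are hypothetical and this is not the code odin actually generates.

```c
#include <stddef.h>

/* Two-step form: materialize foi[i, j] = foi_mat[i, j] * I[j],
 * then take the row sums. Needs an n*n scratch array `foi`. */
void foi_two_step(int n, const double *foi_mat, const double *I,
                  double *foi, double *row_sum)
{
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < n; ++i) {
            foi[i + (size_t)j * n] = foi_mat[i + (size_t)j * n] * I[j];
        }
    }
    for (int i = 0; i < n; ++i) {
        double s = 0.0;
        for (int j = 0; j < n; ++j) {
            s += foi[i + (size_t)j * n];
        }
        row_sum[i] = s;
    }
}

/* Fused form: row_sum = foi_mat %*% I, accumulated column by column
 * with no n*n temporary. */
void foi_fused(int n, const double *foi_mat, const double *I,
               double *row_sum)
{
    for (int i = 0; i < n; ++i) {
        row_sum[i] = 0.0;
    }
    for (int j = 0; j < n; ++j) {
        for (int i = 0; i < n; ++i) {
            row_sum[i] += foi_mat[i + (size_t)j * n] * I[j];
        }
    }
}
```

Both functions do the same number of multiplications. The two-step form additionally writes and re-reads N² doubles at every derivative evaluation, and neither hand-written loop gets the blocking and vectorization that an optimized BLAS `dgemv` would apply.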