mixOmicsTeam / mixOmics

Development repository for the Bioconductor package 'mixOmics'
http://mixomics.org/

Repeated matrix calculation in `predict()` inflates runtime #236

Open Max-Bladen opened 2 years ago

Max-Bladen commented 2 years ago

🐞 Describe the bug:

Referring to the following formula used to make predictions via a (s)PLS model:

$Y = XW(P'W)^{-1}C' + Y_r$

In the academic literature, the $W(P'W)^{-1}$ matrix is referred to as $W^*$. Looking at the following lines (593-603 from `predict()`), we can see the calculation involving $W^*$ occurring three times within a loop. This matrix is constant, as $W$, $P$ and $C$ are all outputs of the fitted model used to make predictions. Hence, this repeated calculation is extremely inefficient and bloats runtime significantly.

```r
# W(P'W)^{-1} is computed here ...
Ypred = lapply(1:ncomp[i], function(x) {
    concat.newdata[[i]] %*% Wmat[, 1:x] %*%
        solve(t(Pmat[, 1:x]) %*% Wmat[, 1:x]) %*% t(Cmat)[1:x, ]
})
Ypred = sapply(Ypred, function(x) {x * sigma.Y + means.Y}, simplify = "array")

# in case of one observation and only one Y, array() keeps a third dimension for ncomp
Y.hat[[i]] = array(Ypred, c(nrow(newdata[[i]]), ncol(Y), ncomp[i]))

# ... again here ...
t.pred[[i]] = concat.newdata[[i]] %*% Wmat %*% solve(t(Pmat) %*% Wmat)
t.pred[[i]] = matrix(
    data = sapply(1:ncol(t.pred[[i]]), function(x) {
        t.pred[[i]][, x] *
            apply(variatesX[[i]], 2, function(y) {(norm(y, type = "2"))^2})[x]
    }),
    nrow = nrow(concat.newdata[[i]]), ncol = ncol(t.pred[[i]]))

# ... and a third time here
B.hat[[i]] = sapply(1:ncomp[i], function(x) {
    Wmat[, 1:x] %*% solve(t(Pmat[, 1:x]) %*% Wmat[, 1:x]) %*% t(Cmat)[1:x, ]
}, simplify = "array")
```
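To make the repetition concrete, here is a small numpy sketch (the shapes and names are illustrative assumptions, not mixOmics code) showing that hoisting $W^* = W(P'W)^{-1}$ out of the three statements leaves all three results unchanged:

```python
import numpy as np

rng = np.random.default_rng(1)
n, p, q, ncomp = 50, 200, 4, 3         # assumed shapes: samples, X-vars, Y-vars, components
X = rng.normal(size=(n, p))            # stands in for concat.newdata[[i]]
W = rng.normal(size=(p, ncomp))        # Wmat
P = rng.normal(size=(p, ncomp))        # Pmat
C = rng.normal(size=(q, ncomp))        # Cmat

# Original pattern: (P'W)^{-1} is evaluated separately for Y.hat, t.pred and B.hat
Y_hat  = X @ W @ np.linalg.solve(P.T @ W, C.T)
t_pred = X @ W @ np.linalg.inv(P.T @ W)
B_hat  = W @ np.linalg.inv(P.T @ W) @ C.T

# Hoisted pattern: W* = W (P'W)^{-1} is computed once and reused everywhere
W_star = W @ np.linalg.inv(P.T @ W)
assert np.allclose(Y_hat,  X @ W_star @ C.T)
assert np.allclose(t_pred, X @ W_star)
assert np.allclose(B_hat,  W_star @ C.T)
```

The three assertions pass because each statement only differs in what multiplies the shared $W^*$ factor.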

🤔 Expected behavior: The $W^*$ calculation should be performed once, rather than three times within a loop.


💡 Possible solution: A simple, initial fix will be implemented first: performing the calculation once, prior to the loop, and referring to the stored object should reduce runtime.

A more complicated solution would be to adjust the output of our various functions (e.g. spls(), splsda(), etc.) to return these crucial matrices. This would allow users to understand the model and use it for their own purposes, as well as decrease runtime by reducing the number of matrix calculations.
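As a rough sketch of that second idea (hypothetical function names, numpy standing in for the R internals), the fitting step could store $W^*$ so that prediction never calls a solve at all:

```python
import numpy as np

def fit_pls_like(W, P, C):
    """Hypothetical: bundle the model matrices plus the precomputed W* = W (P'W)^{-1}."""
    return {"W": W, "P": P, "C": C,
            "W_star": W @ np.linalg.inv(P.T @ W)}

def predict_pls_like(model, X_new):
    """Reuse the stored W*; only two matrix products at prediction time."""
    return X_new @ model["W_star"] @ model["C"].T

rng = np.random.default_rng(2)
p, q, ncomp = 100, 3, 2
model = fit_pls_like(rng.normal(size=(p, ncomp)),   # illustrative random W
                     rng.normal(size=(p, ncomp)),   # illustrative random P
                     rng.normal(size=(q, ncomp)))   # illustrative random C
X_new = rng.normal(size=(10, p))
Y_hat = predict_pls_like(model, X_new)              # shape (10, 3)
```

This mirrors the suggestion above: the expensive constant is paid once at fit time and every subsequent `predict()` call just reads it back.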


Credit: This issue was reported by @psalguerog. I greatly appreciate you bringing this to my attention.

Max-Bladen commented 1 year ago

Here's a brief summary of the work I've done so far, specifically regarding the inflated runtime.

The old set of code can be seen above. It was suspected to be inefficient because $W^*$ was seemingly calculated three times.

I adjusted the code as can be seen below:

```r
# compute W* once, before the per-component loop
W.star <- Wmat %*% solve(t(Pmat) %*% Wmat)

B.hat[[i]] = sapply(1:ncomp[i], function(x) {
    matrix(W.star[, 1:x], ncol = x) %*% t(Cmat)[1:x, ]
}, simplify = "array")

# Prediction Y.hat, B.hat and t.pred
Ypred = lapply(1:ncomp[i], function(x) {concat.newdata[[i]] %*% B.hat[[i]][, , x]})
Ypred = sapply(Ypred, function(x) {x * sigma.Y + means.Y}, simplify = "array")

# in case of one observation and only one Y, array() keeps a third dimension for ncomp
Y.hat[[i]] = array(Ypred, c(nrow(newdata[[i]]), ncol(Y), ncomp[i]))

t.pred[[i]] = concat.newdata[[i]] %*% W.star
t.pred[[i]] = matrix(
    data = sapply(1:ncol(t.pred[[i]]), function(x) {
        t.pred[[i]][, x] *
            apply(variatesX[[i]], 2, function(y) {(norm(y, type = "2"))^2})[x]
    }),
    nrow = nrow(concat.newdata[[i]]), ncol = ncol(t.pred[[i]]))
```
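One subtlety worth noting: slicing `W.star[, 1:x]` gives the same answer as the original per-component `solve()` only because NIPALS deflation makes $P'W$ unit upper triangular. The following numpy sketch, using a minimal NIPALS PLS2 (an illustrative reimplementation, not mixOmics internals), checks that property and the resulting equality:

```python
import numpy as np

def nipals_pls2(X, Y, ncomp, tol=1e-12, max_iter=500):
    """Minimal NIPALS PLS2: returns loading weights W, X-loadings P, Y-loadings C."""
    Xk, Yk = X.copy(), Y.copy()
    W, P, C = [], [], []
    for _ in range(ncomp):
        u = Yk[:, [0]]
        for _ in range(max_iter):
            w = Xk.T @ u
            w /= np.linalg.norm(w)
            t = Xk @ w
            c = Yk.T @ t / (t.T @ t)
            u_new = Yk @ c / (c.T @ c)
            if np.linalg.norm(u_new - u) < tol:
                break
            u = u_new
        p = Xk.T @ t / (t.T @ t)
        Xk -= t @ p.T          # deflation is what makes P'W unit upper triangular
        Yk -= t @ c.T
        W.append(w); P.append(p); C.append(c)
    return np.hstack(W), np.hstack(P), np.hstack(C)

rng = np.random.default_rng(3)
X = rng.normal(size=(30, 15))
Y = rng.normal(size=(30, 3))
W, P, C = nipals_pls2(X, Y, ncomp=3)

W_star = W @ np.linalg.inv(P.T @ W)
for h in range(1, 4):
    # truncated recomputation (old code) vs sliced W* (new code)
    old = W[:, :h] @ np.linalg.inv(P[:, :h].T @ W[:, :h]) @ C[:, :h].T
    new = W_star[:, :h] @ C[:, :h].T
    assert np.allclose(old, new)
```

For a general (non-deflated) pair of $W$ and $P$ matrices the slicing shortcut would not hold, so the refactor leans on this structural property of PLS.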

The main differences are:

- $W^*$ is computed once, as `W.star`, before the per-component loop.
- `B.hat` is computed first (from `W.star`) and then reused to form `Ypred`, instead of repeating the full matrix product.
- `t.pred` is computed directly from `W.star`.

Since this reduces the number of required matrix multiplications, I assumed it would reduce runtime. Using:

```r
# liver.toxicity data shipped with mixOmics
X <- liver.toxicity$gene
Y <- liver.toxicity$clinic
```

and randomly generating 100 samples with the same number of columns as X (each drawn from its own normal distribution) to use as testing data, I ran the default predict() function and the adjusted predict() function 5000 times each and compared their runtimes. For peace of mind, on every iteration the predictions from the two forms of predict() were validated to be equal (to 10 significant figures). Histograms of the runtimes can be seen below:
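A numpy analogue of that benchmark-plus-validation loop might look like the sketch below (the shapes, iteration counts, and the construction of P are assumptions; the actual comparison was run on the R functions). P is built so that $P'W$ is unit upper triangular, mimicking what NIPALS deflation guarantees:

```python
import numpy as np
import timeit

rng = np.random.default_rng(4)
n, p, q, ncomp = 100, 3000, 10, 2      # assumed shapes, roughly gene-expression sized
X_new = rng.normal(size=(n, p))
W = rng.normal(size=(p, ncomp))
C = rng.normal(size=(q, ncomp))
# Contrived P giving unit-upper-triangular P'W, mimicking the NIPALS property
U = np.triu(rng.normal(size=(ncomp, ncomp)), k=1) + np.eye(ncomp)
P = W @ np.linalg.inv(W.T @ W) @ U.T

def predict_old():
    # recompute the inverse for every component (mirrors the original loop)
    return [X_new @ W[:, :h] @ np.linalg.inv(P[:, :h].T @ W[:, :h]) @ C[:, :h].T
            for h in range(1, ncomp + 1)]

def predict_new():
    # hoist W* out of the loop (mirrors the adjusted code)
    W_star = W @ np.linalg.inv(P.T @ W)
    return [X_new @ W_star[:, :h] @ C[:, :h].T for h in range(1, ncomp + 1)]

# Validate equality first (the "peace of mind" check), then time both variants
for a, b in zip(predict_old(), predict_new()):
    assert np.allclose(a, b)
print("old:", timeit.timeit(predict_old, number=200))
print("new:", timeit.timeit(predict_new, number=200))
```

No timing outcome is asserted here; as the histograms below show, the relative cost depends on how much of the total work the hoisted term actually accounts for.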

[Histograms of runtimes for the default and adjusted predict() implementations]

So you can see that the runtime was not improved at all; if anything, it was made worse. The same held when using subsets of the liver.toxicity$gene data for training and testing. This result seems counterintuitive.