willwerscheid / flashier

A faster and angrier package for EBMF.
https://willwerscheid.github.io/flashier/

slow initialization #102

Closed: stephens999 closed this issue 6 months ago

stephens999 commented 1 year ago

When var_type = 2, initializing factors can be slower than expected. This appears to be due to the slow computation of R2 in calc.R2, particularly this line:

EFsq <- premult.nmode.prod.r1(Z, lowrank.expand(EF)^2, r1.ones(flash), n)

in residuals_and_R2.R

Here are some comparisons:

library(irlba)
library(Matrix)
library(flashier)
library(fastTopics) # provides the pbmc_facs data set
data(pbmc_facs)
X = Matrix(pbmc_facs$counts,sparse=TRUE)
s = irlba(X,nv=5) # get first 5 singular vectors
# flash objects with var_type=2 and var_type=0 residual variance structures
f2 = flash_init(X,var_type=2)
f0 = flash_init(X,var_type=0)
# time a full greedy fit, then factor initialization from the SVD
system.time(f.1 <- flash(X,greedy_Kmax = 5,var_type=2))
system.time(f2.2 <- flash_factors_init(f2,s))
system.time(f0.2 <- flash_factors_init(f0,s))

# time the R2 computation directly
system.time(flashier:::calc.R2(f2.2$flash_fit))
system.time(flashier:::calc.R2(f0.2$flash_fit))

willwerscheid commented 1 year ago

Rewriting as follows seems to fix the problem:

Browse[2]> system.time(EFsq <- premult.nmode.prod.r1(Z, lowrank.expand(EF)^2, r1.ones(flash), n))
   user  system elapsed 
  0.747   0.184   0.932 
Browse[2]> system.time(EFsq <- colSums(
+     apply(EF[[n]], 1, tcrossprod) * as.vector(crossprod(EF[[-n]]))
+ ))
   user  system elapsed 
  0.015   0.002   0.017 

This will only work for matrices (not tensors) so I am leaving a TODO in there for tensors.
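
For readers who want to see why the two expressions agree in the matrix case, here is a minimal sketch on synthetic data. It is not flashier's code: `Lmat`, `Fmat`, `n_rows`, `n_cols`, and `K` are illustrative names standing in for `EF[[n]]`, `EF[[-n]]`, and their dimensions, and the sketch omits the `Z` argument from the original call, whose role is not described in this thread. The idea is that row i of the squared low-rank expansion sums to l_i' (F'F) l_i, so the full matrix never has to be formed.

```r
# Synthetic check of the identity behind the rewrite (matrix case only).
# All names below are hypothetical stand-ins, not flashier internals.
set.seed(1)
n_rows <- 200
n_cols <- 300
K <- 4  # assumes K >= 2: with K = 1, apply() returns a vector and colSums() fails

Lmat <- matrix(rnorm(n_rows * K), n_rows, K)  # plays the role of EF[[n]]
Fmat <- matrix(rnorm(n_cols * K), n_cols, K)  # plays the role of EF[[-n]]

# Naive version: expand the full n_rows x n_cols matrix, square it
# elementwise, and sum across columns.
naive <- rowSums(tcrossprod(Lmat, Fmat)^2)

# Rewritten version: column i of apply(...) is the vectorized K x K outer
# product of row i of Lmat. Weighting by the vectorized Gram matrix of Fmat
# and summing down each column gives l_i' (F'F) l_i.
fast <- colSums(apply(Lmat, 1, tcrossprod) * as.vector(crossprod(Fmat)))

# Compare the two results up to floating-point tolerance.
all.equal(naive, fast)
```

The rewritten version only ever holds a K^2 x n_rows intermediate and a K x K Gram matrix, rather than the n_rows x n_cols expansion.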

willwerscheid commented 1 year ago

Running your code now gives:

> system.time(f.1 <- flash(X,greedy_Kmax = 5,var_type=2))
Adding factor 1 to flash object...
Adding factor 2 to flash object...
Adding factor 3 to flash object...
Adding factor 4 to flash object...
Adding factor 5 to flash object...
Wrapping up...
Done.
Nullchecking 5 factors...
Done.
   user  system elapsed 
  3.454   0.451   3.906 
> system.time(f2.2 <- flash_factors_init(f2,s))
   user  system elapsed 
  0.033   0.001   0.035 
> system.time(f0.2 <- flash_factors_init(f0,s))
   user  system elapsed 
  0.023   0.001   0.023 
> 
> system.time(flashier:::calc.R2(f2.2$flash_fit))
   user  system elapsed 
  0.024   0.001   0.025 
> system.time(flashier:::calc.R2(f0.2$flash_fit)
+ )
   user  system elapsed 
  0.009   0.001   0.009 

willwerscheid commented 1 year ago

Please verify, @stephens999, and I will close.

willwerscheid commented 1 year ago

Do note, however, that this will only help when K^2 << p (var_type = 2) or K^2 << n (var_type = 1).

stephens999 commented 1 year ago

It looks good to me. (Did you make the change in a branch? I wasn't sure how to verify.)

Do you know the computational complexity? Is it the same complexity for var_type=0,2?

willwerscheid commented 1 year ago

I think the complexity is K^2 min(n, p) for both. The changes are in the main branch.