desh2608 / gss

A simple package for Guided source separation (GSS)
MIT License

Fix einsum path in cp.einsum #19

Closed desh2608 closed 1 year ago

desh2608 commented 1 year ago

Since we roughly know the shapes of the tensors, we can fix (hard-code) the einsum_path once instead of computing the optimal path on every call.

For the contraction in the CACG _log_pdf, the optimal path gives the following theoretical FLOP reduction:

  Complete contraction:  cfdt,cfde,cfe,cfge,cfgt->cft
         Naive scaling:  6
     Optimized scaling:  5
      Naive FLOP count:  1.480e+10
  Optimized FLOP count:  9.699e+8
   Theoretical speedup:  1.526e+1
  Largest intermediate:  6.040e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4              0         cfe,cfde->cfed              cfdt,cfge,cfgt,cfed->cft
   5              0        cfed,cfge->cfdg                   cfdt,cfgt,cfdg->cft
   5              0        cfdg,cfdt->cfgt                        cfgt,cfgt->cft
   4              0         cfgt,cfgt->cft                              cft->cft
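
A report in this format can be reproduced with `np.einsum_path`. The snippet below is only a sketch: it uses NumPy rather than CuPy, the dimension sizes are small hypothetical values (the real sizes in gss are not stated here), and the operand names `a` to `e` are placeholders, not the variables used in the package.

```python
import numpy as np

# Hypothetical small sizes, chosen only so the example runs quickly.
C, F, D, T = 2, 3, 4, 5
rng = np.random.default_rng(0)


def crandn(*shape):
    """Random complex array with the given shape."""
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


# Placeholder operands whose shapes match the subscripts of the contraction.
a = crandn(C, F, D, T)        # cfdt
b = crandn(C, F, D, D)        # cfde
c = rng.random((C, F, D))     # cfe
d = crandn(C, F, D, D)        # cfge
e = crandn(C, F, D, T)        # cfgt

subscripts = "cfdt,cfde,cfe,cfge,cfgt->cft"

# einsum_path returns the contraction order and a printable report.
path, report = np.einsum_path(subscripts, a, b, c, d, e, optimize="optimal")
print(report)
```

The FLOP counts printed for these toy sizes will differ from the numbers above, since they depend on the actual dimensions.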

The main speedup comes from the CACG log_pdf computation (~1.5x). Note that the original implementation also used the optimal contraction path, but it recomputed that path at every iteration, which is very time-consuming (see here). Instead, we now fix the optimal path once and reuse it.
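
As a rough illustration of fixing the path, the contraction order in the table can be written as an explicit path list and passed through `optimize=`. This is a sketch under assumptions: it uses NumPy instead of CuPy, the sizes and operand names are hypothetical, and the literal path was derived by hand from the table rather than copied from the gss code.

```python
import numpy as np

SUBSCRIPTS = "cfdt,cfde,cfe,cfge,cfgt->cft"

# Explicit path read off the table: each tuple gives the positions of the
# operands contracted at that step; the result is appended to the end of
# the remaining operand list.
#   (1, 2): cfde,cfe   -> cfed
#   (1, 3): cfge,cfed  -> cfdg
#   (0, 2): cfdt,cfdg  -> cfgt
#   (0, 1): cfgt,cfgt  -> cft
FIXED_PATH = ["einsum_path", (1, 2), (1, 3), (0, 2), (0, 1)]


def contract(*operands):
    """Run the contraction with the hard-coded path (no path search)."""
    return np.einsum(SUBSCRIPTS, *operands, optimize=FIXED_PATH)


# Hypothetical small sizes and placeholder operands for a quick check.
C, F, D, T = 2, 3, 4, 5
rng = np.random.default_rng(0)


def crandn(*shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)


ops = (
    crandn(C, F, D, T),       # cfdt
    crandn(C, F, D, D),       # cfde
    rng.random((C, F, D)),    # cfe
    crandn(C, F, D, D),       # cfge
    crandn(C, F, D, T),       # cfgt
)

result = contract(*ops)
# Compare against the unoptimized single-step einsum.
reference = np.einsum(SUBSCRIPTS, *ops, optimize=False)
```

Because the path is a constant, the per-call cost of searching for a contraction order disappears, while the result should match the unoptimized contraction up to floating-point rounding.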